In [1]:
import os.path
import zipfile

import numpy as np

from skimage import io as skimageio
from skimage import color as skimagecolor


In [2]:
# hyperparameter
batch_size = 128

In [3]:
# Env check
# check dataset file exists
dataset_path = 'garbage_classification.zip'
if not os.path.isfile(dataset_path):
    print("The dataset is not found. Please put it in the root of the directory and rename it as \"garbage_classification.zip\".")
    print("The link of the dataset is: https://drive.google.com/file/d/1kcwBy_yG47Mp2iyq6Oo6ACojgdXRt4Bs/view?usp=sharing")
    exit()

# extract the dataset
if not os.path.exists("garbage_classification/"):
    print("extracting the dataset... ")
    with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
        zip_ref.extractall("./")
print("The dataset has been extracted.")

The dataset has been extracted.


In [4]:
# define dataloader
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

categories = {
    "battery": 0,
    "biological": 1,
    "brown-glass": 2,
    "cardboard": 3,
    "clothes": 4,
    "green-glass": 5,
    "metal": 6,
    "paper": 7,
    "plastic": 8,
    "shoes": 9,
    "trash": 10,
    "white-glass": 11,
}

data_transforms = {
    "train": transforms.Compose(
        [
            # data augmentation
            transforms.RandomRotation(45),  # random rotation from -45 to 45 degrees
            transforms.RandomHorizontalFlip(
                p=0.5
            ),  # random rotation from probability p
            transforms.RandomVerticalFlip(
                p=0.5
            ),  # random vertical flip from probability p
            transforms.ColorJitter(
                brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1
            ),
            transforms.RandomGrayscale(
                p=0.025
            ),  # randomly transform color to grayscae, R=G=B for 3 channels
            # ritual transformation
            transforms.Resize(256),
            transforms.CenterCrop(224),  # center crop
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "valid": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}


class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, "r")
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split("/")
            imgs.append((line, categories[words[1]]))
            print(categories[words[1]])
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        fn, label = self.imgs[
            index
        ]  # self.imgs是一个list，self.imgs的一个元素是一个str，包含图片路径，图片标签，这些信息是在init函数中从txt文件中读取的
        # fn是一个图片路径
        img = Image.open(fn).convert(
            "RGB"
        )  # 利用Image.open对图片进行读取，img类型为 Image ，mode=‘RGB’
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# Load dataset
train_path_file = "garbage_classification/train.txt"
test_path_file = "garbage_classification/test.txt"
val_path_file = "garbage_classification/val.txt"
train_data = MyDataset(
    txt_path=train_path_file, transform=data_transforms["train"]
)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
valid_data = MyDataset(
    txt_path=val_path_file, transform=data_transforms["valid"]
)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size)

# show the first 10 data
# for epoch in range(1):
#     for step, (path, label) in enumerate(train_loader):
#         print('Epoch: ', epoch, '| Step: ', step, '| path: ',
#               path, '| label: ', label)

  from .autonotebook import tqdm as notebook_tqdm


0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


In [5]:
# Load dataset
train_path_file = "garbage_classification/train.txt"
test_path_file = "garbage_classification/test.txt"
val_path_file = "garbage_classification/val.txt"
# read path file
train_file_paths = []
test_file_paths = []
val_file_paths = []
with open(train_path_file) as file:
    train_file_paths = [line.rstrip() for line in file]
with open(test_path_file) as file:
    test_file_paths = [line.rstrip() for line in file]
with open(val_path_file) as file:
    val_file_paths = [line.rstrip() for line in file]
print("The size of train data: ",len(train_file_paths))
print("The size of test data: ",len(test_file_paths))
print("The size of val data: ",len(val_file_paths))

# read dataset
w_min = 100000
w_max = -1
h_min = 100000
h_max = -1
# for path in train_file_paths:
#     im = skimageio.imread(path)
#     if h_min > im.shape[0]:
#         h_min = im.shape[0]
#     if w_min > im.shape[1]:
#         w_min = im.shape[1]
#     if h_max < im.shape[0]:
#         h_max = im.shape[0]
#     if w_max < im.shape[1]:
#         w_max = im.shape[1]

print(w_min,h_min, w_max, h_max)
# print(im.shape) # numpy矩阵，(h,w,c)
# print(im.dtype)
# print(im.size)

# im3 = skimagecolor.rgb2grey(im)# 灰度化处理
# skimageio.imshow(im3)




The size of train data:  9923
The size of test data:  3106
The size of val data:  2486
100000 100000 -1 -1


In [6]:
torch.cuda.is_available()
from tqdm import tqdm

In [7]:
# 防止SSL报错
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# define the model
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from einops.layers.torch import Rearrange
from tqdm import tqdm

# 定义一个基本的Transformer块
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended, _ = self.attention(x, x, x)
        x = x + self.dropout(attended)
        x = self.norm1(x)
        mlp_output = self.mlp(x)
        x = x + self.dropout(mlp_output)
        x = self.norm2(x)
        return x

# 定义Vision Transformer模型
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, num_heads, num_layers, mlp_dim, dropout=0.1):
        super(VisionTransformer, self).__init__()
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2  # 3 channels for RGB images

        self.embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.transformer = nn.Sequential(
            *[TransformerBlock(dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)]
        )

        self.classification_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # Transpose
        x = self.transformer(x)
        x = x.mean(dim=0)  # Global average pooling
        x = self.classification_head(x)
        return x

# 数据加载和预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 定义模型和优化器
image_size = 224
patch_size = 16
num_classes = 12
model = VisionTransformer(image_size, patch_size, num_classes, dim=256, num_heads=8, num_layers=12, mlp_dim=512, dropout=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 训练模型
epochs = 12
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    loop = tqdm(train_loader)
    for batch_idx, (data, target) in enumerate(loop):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # add stuff to progress bar in the end
        loop.set_description(f"Epoch [{epoch}/{epochs}]")
        loop.set_postfix(loss=torch.rand(1).item(), acc=torch.rand(1).item())
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader)}")

# 保存模型
torch.save(model.state_dict(), 'vision_transformer_model.pth')


Epoch [0/1]: 100%|██████████| 78/78 [08:26<00:00,  6.49s/it, acc=0.494, loss=0.233]  

Epoch 1/1, Loss: 2.679076576432309





In [8]:
# test the model
model.eval()
# eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)with torch.no_grad():
correct = 0
total = 0
for images, labels in valid_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

print(
    "Test Accuracy of the model on the 10000 test images: {} %".format(
        100 * correct / total
    )
)

Test Accuracy of the model on the 10000 test images: 12.751407884151247 %
