In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
from PIL import Image
import numpy as np
import os, time
import random
random.seed(0)


readvdnames = lambda x: open(x).read().rstrip().split('\n')

################################# DEFINE DATASET #################################
class TinySegData(Dataset):
    def __init__(self, db_root="TinySeg", img_size=256, phase='train'):
        classes = ['person', 'bird', 'car', 'cat', 'plane', ]
        seg_ids = [1, 2, 3, 4, 5]

        templ_image = db_root + "/JPEGImages/{}.jpg"
        templ_mask = db_root + "/Annotations/{}.png"

        ids = readvdnames(db_root + "/ImageSets/" + phase + ".txt")

        # build training and testing dbs
        samples = []
        for i in ids:
            samples.append([templ_image.format(i), templ_mask.format(i)])
        self.samples = samples
        self.phase = phase
        self.db_root = db_root
        self.img_size = img_size

        self.color_transform = torchvision.transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2)

        if not self.phase == 'train':
            print ("resize and augmentation will not be applied...")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.phase == 'train':
            return self.get_train_item(idx)
        else:
            return self.get_train_item(idx)

    def get_train_item(self, idx):
        sample = self.samples[idx]
        image = Image.open(sample[0])

        if random.randint(0, 1) > 0:
            image = self.color_transform(image)
        image = np.asarray(image)[..., ::-1]     # to BGR
        seg_gt = (np.asarray(Image.open(sample[1]).convert('P'))).astype(np.uint8)

        image = image.astype(np.float32)
        image = image / 127.5 - 1        # -1~1

        if random.randint(0, 1) > 0:
            image = image[:, ::-1, :]       # HWC
            seg_gt = seg_gt[:, ::-1]

        # random crop to 256x256
        height, width = image.shape[0], image.shape[1]
        if height == width:
            miny, maxy = 0, 256
            minx, maxx = 0, 256
        elif height > width:
            miny = np.random.randint(0, height-256)
            maxy = miny+256
            minx = 0
            maxx = 256
        else:
            miny = 0
            maxy = 256
            minx = np.random.randint(0, width-256)
            maxx = minx+256
        image = image[miny:maxy, minx:maxx, :].copy()
        seg_gt = seg_gt[miny:maxy, minx:maxx].copy()

        if self.img_size != 256:
            new_size = (self.img_size, self.img_size)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
            seg_gt = cv2.resize(seg_gt, new_size, interpolation=cv2.INTER_NEAREST)

        image = np.transpose(image, (2, 0, 1))      # To CHW

        # cv2.imwrite("test.png", np.concatenate([(image[0]+1)*127.5, seg_gt*255], axis=0))
        return image, seg_gt, sample

    def get_test_item(self, idx):
        sample = self.samples[idx]
        image = cv2.imread(sample[0])
        seg_gt = (np.asarray(Image.open(sample[1]).convert('P'))).astype(np.uint8)

        image = image.astype(np.float32)
        image = image / 127.5 - 1        # -1~1
        image = np.transpose(image, (2, 0, 1))

        # cv2.imwrite("test.png", np.concatenate([(image[0]+1)*127.5, seg_gt*255], axis=0))
        return image, seg_gt, sample

################################# FUNCTIONS #################################
def get_confusion_matrix(gt_label, pred_label, class_num):
        """
        Calcute the confusion matrix by given label and pred
        :param gt_label: the ground truth label
        :param pred_label: the pred label
        :param class_num: the number of class
        :return: the confusion matrix
        """
        index = (gt_label * class_num + pred_label).astype('int32')

        label_count = np.bincount(index)
        confusion_matrix = np.zeros((class_num, class_num))

        for i_label in range(class_num):
            for i_pred_label in range(class_num):
                cur_index = i_label * class_num + i_pred_label
                if cur_index < len(label_count):
                    confusion_matrix[i_label, i_pred_label] = label_count[cur_index]

        return confusion_matrix

def get_confusion_matrix_for_3d(gt_label, pred_label, class_num):
    confusion_matrix = np.zeros((class_num, class_num))

    for sub_gt_label, sub_pred_label in zip(gt_label, pred_label):
        sub_gt_label = sub_gt_label[sub_gt_label != 255]
        sub_pred_label = sub_pred_label[sub_pred_label != 255]
        cm = get_confusion_matrix(sub_gt_label, sub_pred_label, class_num)
        confusion_matrix += cm
    return confusion_matrix



In [2]:
class Attention(nn.Module):  
    def __init__(self, embed_size = 3*16*16 , heads = 1, dropout=0.2):  
        super(Attention, self).__init__()  
        self.embed_size = embed_size  
        self.heads = heads  
        self.head_dim = embed_size // heads  

        assert (  
            self.head_dim * heads == embed_size  
        ), "Embedding size must be divisible by heads"  

        self.values = nn.Linear(embed_size, embed_size, bias=False)  
        self.keys = nn.Linear(embed_size, embed_size, bias=False)  
        self.queries = nn.Linear(embed_size, embed_size, bias=False)  
        self.fc_out = nn.Linear(embed_size, embed_size)  

        self.dropout = nn.Dropout(dropout)  

    

    def forward(self, x):    

        values = self.values(x)  # (batch_size, seq_length, embed_size)  
        keys = self.keys(x)      
        queries = self.queries(x)  

       
        # Scaled dot-product attention  
        
        energy = torch.bmm(queries, keys.transpose(-2, -1))  # (batch_size, query_length, key_length)  
        

        # 进行缩放  
        scaling = self.head_dim ** 0.5  
        scaled_energy = energy / scaling  # (batch_size, query_length, key_length)  

        # 应用 softmax  
        attention = nn.functional.softmax(scaled_energy, dim=-1)

        # 计算输出  
        out = torch.bmm(attention, values)  # (batch_size, query_length, head_dim)  
        # 注意：这里的 values 仍然是 (batch_size, heads, seq_length, head_dim)  

    

        # 最后，通过线性层输出  
        out = self.fc_out(out)  
        return self.dropout(out)  
    
class VisionEncoder(nn.Module):  
    def __init__(self, embed_size = 3*16*16, heads = 1, drop_rate=0.2):  
        super(VisionEncoder, self).__init__()  
        self.attention = Attention(embed_size, heads, dropout=drop_rate)  
        self.norm1 = nn.LayerNorm(embed_size)  
        self.norm2 = nn.LayerNorm(embed_size)  
        self.mlp = nn.Sequential(  
            nn.Linear(embed_size, 2048),  
            nn.GELU(),  
            nn.Dropout(drop_rate),  
            nn.Linear(2048, embed_size),  
            nn.Dropout(drop_rate)  
        )  

    def forward(self, x):  
        attention = self.attention(x)  
        x = self.norm1(attention + x)  # Residual Connection  
        mlp_out = self.mlp(x)  
        x = self.norm2(mlp_out + x)  # Residual Connection  
        return x  

class VisionTransformer(nn.Module):  
    def __init__(self, num_classes = 20, embed_size=3*16*16, num_layers=3, heads=1, num_patches=14*14, drop_rate=0.2):  
        super(VisionTransformer, self).__init__()  
        self.patch = nn.Conv2d(3,embed_size,16,16)
        self.encoders = nn.ModuleList(  
            [VisionEncoder(embed_size, heads, drop_rate) for _ in range(num_layers)]  
        )  
        self.norm = nn.LayerNorm(embed_size)  
        self.classifier = nn.Linear(embed_size, num_classes)  
        self.dropout = nn.Dropout(drop_rate)  

    def forward(self, x):  
        x = self.patch(x)
        x = x.flatten(2).permute(0,2,1)
        
        for encoder in self.encoders:  
            x = encoder(x)  
        
        x = self.norm(x)  
        x = self.dropout(x)  
        x = x.mean(dim=1)  # Global average pooling  
        return self.classifier(x)  


In [3]:
dataset = TinySegData(img_size=224, phase='train')
test = TinySegData(img_size=224,phase='val')

resize and augmentation will not be applied...


In [4]:
train_loader = DataLoader(dataset, batch_size=64)
test_loader = DataLoader(test,batch_size=32)
len(train_loader)

94

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  


In [7]:
example = VisionTransformer().to(device)

lr = 0.0001
epoches = 40
criterion = nn.BCEWithLogitsLoss().to(device)
optimizer = torch.optim.AdamW(example.parameters(),lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)  
loss_list = []
accuracy_list = []
loss_list1 = []
accuracy_list1 = []

for epoch in range(200):
    train_loss = 0
    train_acc = 0

    for img,label,_ in train_loader:
        label = torch.max(label.flatten(1),dim=1)[0].long()
        label = nn.functional.one_hot(label,20).float()
        img,label = img.to(device),label.to(device)
        output = example(img)
        loss = criterion(output,label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _,pred = output.max(1)
        _,result = label.max(1)
        num_correct = (pred==result).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc
    scheduler.step()  

    loss_list.append(train_loss/len(train_loader))
    accuracy_list.append(train_acc/len(train_loader))
    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}'.format(epoch+1, train_loss/len(train_loader), train_acc/len(train_loader)))


epoch: 1, Train Loss: 0.128710, Train Acc: 0.506926
epoch: 2, Train Loss: 0.102692, Train Acc: 0.530197
epoch: 3, Train Loss: 0.097559, Train Acc: 0.555519
epoch: 4, Train Loss: 0.093686, Train Acc: 0.572695
epoch: 5, Train Loss: 0.091486, Train Acc: 0.583112
epoch: 6, Train Loss: 0.090196, Train Acc: 0.600399
epoch: 7, Train Loss: 0.087856, Train Acc: 0.604998
epoch: 8, Train Loss: 0.087754, Train Acc: 0.611536
epoch: 9, Train Loss: 0.086006, Train Acc: 0.614251
epoch: 10, Train Loss: 0.085187, Train Acc: 0.626496
epoch: 11, Train Loss: 0.084372, Train Acc: 0.627770
epoch: 12, Train Loss: 0.084548, Train Acc: 0.630762
epoch: 13, Train Loss: 0.083482, Train Acc: 0.625388
epoch: 14, Train Loss: 0.083114, Train Acc: 0.626441
epoch: 15, Train Loss: 0.081754, Train Acc: 0.637467
epoch: 16, Train Loss: 0.081868, Train Acc: 0.643562
epoch: 17, Train Loss: 0.080822, Train Acc: 0.642841
epoch: 18, Train Loss: 0.081486, Train Acc: 0.640403
epoch: 19, Train Loss: 0.080384, Train Acc: 0.647717
ep

In [8]:
test_acc = 0
for img,label,_ in test_loader:
        label = torch.max(label.flatten(1),dim=1)[0].long()
        label = nn.functional.one_hot(label,20).float()
        img,label = img.to(device),label.to(device)
        output = example(img)



        _,pred = output.max(1)
        _,result = label.max(1)
        num_correct = (pred==result).sum().item()
        acc = num_correct / img.shape[0]
        test_acc += acc
print(test_acc/len(test_loader))

0.6703125
