# swin_transformer
<img src="img/swin_transformer.png" width="100%">  

In [1]:
data_path = '../../dataset/bird_datasets/train'
classes_path = '../../dataset/bird_datasets/classes.txt'
training_labels_path = '../../dataset/bird_datasets/training_labels.txt'

BATCH_SIZE = 8
WORKERS = 16
epochs = 100
learning_rate = 2e-4
weight_decay = 1e-4
momentum = 0.9
label_smooth=0.2
# pretrain = None
pretrain = 'model/model_bird_vic_simsiam_pretrain/checkpoint.pth.tar'

output_foloder = 'model/model_bird_vit_vic_TripletLoss'

In [2]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/' + output_foloder, comment=f' batch_size={BATCH_SIZE} lr={learning_rate}')
writer.add_text('Remark', 'batch_size = {}'.format(BATCH_SIZE) , 0)
writer.add_text('Remark', 'learning_rate = {}'.format(learning_rate) , 0)
writer.add_text('Remark', 'momentum = {}'.format(momentum) , 0)
writer.add_text('Remark', 'weight_decay = {}'.format(weight_decay) , 0)
writer.add_text('Remark', 'output_foloder = {}'.format(output_foloder), 0)
writer.add_text('Remark', 'pretrain = {}'.format(pretrain), 0)
writer.add_text('Remark', 'label_smooth = {}'.format(label_smooth), 0)
writer.add_text('Remark', 'TripletLoss', 0)

writer.flush()

# GPU Check

In [3]:
import os
import torch
import numpy as np
import math
import glob
from os import listdir
from os import walk
from torch import nn
from tqdm import tqdm 
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR, StepLR, ReduceLROnPlateau
from torchvision import datasets, transforms
# from swin_transformer_pytorch import SwinTransformer
import torchvision.models as models
from loss_functions.CrossEntropyLS import CrossEntropyLS
from loss_functions.triplet_loss import TripletLoss

import PIL.Image as Image
from matplotlib import pyplot as plt
import torch.nn.functional as F
import timm
from torch.cuda.amp import GradScaler, autocast

In [4]:
print('torch version:' + torch.__version__)

if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Available GPUs: ', end='')
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_name(i), end=' ')
else:
    device = torch.device('cpu')
    print('CUDA is not available.')

torch version:1.9.0+cu102
Available GPUs: GeForce RTX 2080 Ti GeForce GTX 1080 Ti 

# Dataset

#### Define dataset, and dataloader

In [5]:
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def make_dataset(dir, data_list):
    images = []
    for img_name, idx, labels in data_list:
        item = (img_name, int(idx))
        images.append(item)
    return images

class BirdImageLoader(Dataset):
    def __init__(self, root, data_list, class_to_idx, transform=None, target_transform=None):
        imgs = make_dataset(root, data_list)

        self.root = root
        self.imgs = imgs
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)

# Data augmentation

In [6]:
from PIL import ImageFilter
import random

class GaussianBlur(object):
    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
    
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# https://github.com/aniket03/self_supervised_bird_classification/blob/master/dataset_helpers.py
def all_in_aug():
    all_in_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=[0.5, 1.5]),
    transforms.RandomRotation(degrees=15),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return all_in_transform

def get_aug_trnsform():
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    return transform

def get_trnsform():
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(1., 1.)),
        transforms.ToTensor(),
        normalize
    ])
    return transform
trans_aug = get_aug_trnsform()
trans = get_trnsform()

## read classes txt

In [7]:
class_to_idx = {}
with open(classes_path) as f:
    for line in f.readlines():
        label_num =  line.split(".")[0] 
        label_str =  line.split(".")[1][:-1]
        class_to_idx[int(label_num) - 1] = label_str
print(class_to_idx)

{0: 'Black_footed_Albatross', 1: 'Laysan_Albatross', 2: 'Sooty_Albatross', 3: 'Groove_billed_Ani', 4: 'Crested_Auklet', 5: 'Least_Auklet', 6: 'Parakeet_Auklet', 7: 'Rhinoceros_Auklet', 8: 'Brewer_Blackbird', 9: 'Red_winged_Blackbird', 10: 'Rusty_Blackbird', 11: 'Yellow_headed_Blackbird', 12: 'Bobolink', 13: 'Indigo_Bunting', 14: 'Lazuli_Bunting', 15: 'Painted_Bunting', 16: 'Cardinal', 17: 'Spotted_Catbird', 18: 'Gray_Catbird', 19: 'Yellow_breasted_Chat', 20: 'Eastern_Towhee', 21: 'Chuck_will_Widow', 22: 'Brandt_Cormorant', 23: 'Red_faced_Cormorant', 24: 'Pelagic_Cormorant', 25: 'Bronzed_Cowbird', 26: 'Shiny_Cowbird', 27: 'Brown_Creeper', 28: 'American_Crow', 29: 'Fish_Crow', 30: 'Black_billed_Cuckoo', 31: 'Mangrove_Cuckoo', 32: 'Yellow_billed_Cuckoo', 33: 'Gray_crowned_Rosy_Finch', 34: 'Purple_Finch', 35: 'Northern_Flicker', 36: 'Acadian_Flycatcher', 37: 'Great_Crested_Flycatcher', 38: 'Least_Flycatcher', 39: 'Olive_sided_Flycatcher', 40: 'Scissor_tailed_Flycatcher', 41: 'Vermilion_Fly

## read labels txt

In [8]:
data_list = []
with open(training_labels_path) as f:
    for line in f.readlines():
        file_name =  line.split(" ")[0]
        label_num =  int(line.split(" ")[1].split(".")[0]) -1
        label_str =  line.split(" ")[1].split(".")[1][:-1]
        data_list.append([file_name, label_num, label_str])

train_data_list = data_list[:int(len(data_list) * 0.8)]
val_data_list = data_list[int(len(data_list) * 0.8):int(len(data_list) * 0.9)]
test_data_list = data_list[int(len(data_list) * 0.9):]

# train_data_list = data_list[:int(len(data_list))]
# val_data_list = data_list[int(len(data_list) * 0.8):int(len(data_list) * 1)]
# test_data_list = data_list[int(len(data_list) * 0.9):]


print("all data : ", len(data_list))
print("train data : ", len(train_data_list))
print("val data : ", len(val_data_list))
print("test data : ", len(test_data_list))
print(train_data_list[:10])

all data :  3000
train data :  2400
val data :  300
test data :  300
[['4283.jpg', 114, 'Brewer_Sparrow'], ['3982.jpg', 161, 'Canada_Warbler'], ['5836.jpg', 143, 'Common_Tern'], ['5980.jpg', 7, 'Rhinoceros_Auklet'], ['4168.jpg', 160, 'Blue_winged_Warbler'], ['2352.jpg', 60, 'Heermann_Gull'], ['0511.jpg', 37, 'Great_Crested_Flycatcher'], ['4492.jpg', 146, 'Least_Tern'], ['1254.jpg', 131, 'White_crowned_Sparrow'], ['2792.jpg', 176, 'Prothonotary_Warbler']]


In [9]:
dataset_train = BirdImageLoader(data_path, train_data_list, class_to_idx, transform=trans_aug)
dataset_val = BirdImageLoader(data_path, val_data_list, class_to_idx, transform=trans)
dataset_test = BirdImageLoader(data_path, test_data_list, class_to_idx, transform=trans)

train_loader = DataLoader(
    dataset_train,
    num_workers=WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=True
)
val_loader = DataLoader(
    dataset_val,
    num_workers=WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=True
)

test_loader = DataLoader(
    dataset_test,
    num_workers=WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=False
)

print('class_to_idx ', len(dataset_train.class_to_idx))
print('val_loader ', len(dataset_train.class_to_idx))
dataset_train.__len__()


class_to_idx  200
val_loader  200


2400

### Test Data loader

In [10]:
dataset_debug = BirdImageLoader(data_path, train_data_list, class_to_idx, transform=trans_aug)
debug_loader = DataLoader(
    dataset_debug,
    num_workers=WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=True
)

for i_batch, image_batch in tqdm(enumerate(debug_loader)):
    x, y = image_batch[0], image_batch[1]
    ROW, COL = 2, 4
    f, ax = plt.subplots(ROW, COL, figsize=(16, 7))
    for i in range(ROW):
        for j in range(COL):
            img = x[i*COL+j]
            ax[i][j].imshow(img.permute(1, 2, 0))
    plt.show()
    break
print(y)
print(y.size())

0it [00:00, ?it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
0it [00:00, ?it/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-10-1c353b4ce518>", line 17, in <module>
    plt.show()
  File "/opt/conda/lib/python3.7/site-packages/matplotlib/pyplot.py", line 353, in show
    return _backend_mod.show(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/pylab/backend_inline.py", line 43, in show
    metadata=_fetch_figure_metadata(figure_manager.canvas.figure)
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/display.py", line 313, in display
    format_dict, md_dict = format(obj, include=include, exclude=exclude)
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/formatters.py", line 180, in format
    data = formatter(obj)
  File "<decorator-gen-2>", line 2, in __call__
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/formatters.py", line 224, in ca


KeyboardInterrupt



# model

 ### define optimizer, scheduler

In [None]:
# make embedding to length=1
class L2_norm(nn.Module):
    def __init__(self):
        super(L2_norm, self).__init__()

    def forward(self, x):
        return F.normalize(x, p=2, dim=-1)

### fix model

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

## define model

In [None]:
# net = SwinTransformer(
#     hidden_dim=96,
#     layers=(2, 2, 6, 2),
#     heads=(3, 6, 12, 24),
#     channels=3,
#     num_classes=len(dataset_train.class_to_idx),
#     head_dim=32,
#     window_size=7,
#     downscaling_factors=(4, 2, 2, 2),
#     relative_pos_embedding=True
# )
# dummy_x = torch.randn(1, 3, 224, 224)
# logits = net(dummy_x)  # (1,3)
# model = net.to(device)
# print(net)
# print(logits)

In [None]:
backbone = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=True)

if pretrain != None:
    backbone = torch.load(pretrain).to(device)
    set_parameter_requires_grad(backbone, False)

projector = nn.Sequential(
    nn.Linear(11221, 2048), nn.BatchNorm1d(2048), nn.ReLU(),
    nn.Linear(2048, 512), nn.BatchNorm1d(512), nn.ReLU(),
    nn.Linear(512, 200)
)
model = nn.Sequential(backbone, projector).to(device)
model_optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum = momentum, weight_decay=weight_decay)

model_scheduler = ReduceLROnPlateau(model_optimizer, 'min')

#### Define loss and evaluation functions

In [None]:
# loss_fn = torch.nn.CrossEntropyLoss()
loss_fn = CrossEntropyLS(label_smooth)
loss_fn = TripletLoss(device)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

#### Train model

In [None]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [None]:
def update_loss_hist(train_list, val_list, name='result'):
    clear_output(wait=True)
    plt.plot(train_list)
    plt.plot(val_list)
    plt.title(name)
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['train', 'val'], loc='center right')
    plt.savefig('{}/{}.png'.format(output_foloder, name))
    plt.show()

In [None]:
def pass_epoch(loader, mode = 'Train'):
    loss = 0
    acc_top1 = 0
    acc_top5 = 0
    
    for i_batch, image_batch in tqdm(enumerate(loader)):
        x, y = image_batch[0].to(device), image_batch[1].to(device)
        if mode == 'Train':
            model.train()
        elif mode == 'Eval':
            model.eval()
        else:
            print('error model mode!')
        y_pred = model(x)

        loss_batch = loss_fn(y_pred, y)
        loss_batch_acc_top = accuracy(y_pred, y, topk=(1, 5))

        if mode == 'Train':
            model_optimizer.zero_grad()
            scaler.scale(loss_batch).backward()
            scaler.step(model_optimizer)
            scaler.update()
#             loss_batch.backward()
            model_optimizer.step()
        
        loss += loss_batch.detach().cpu()
        acc_top1 += loss_batch_acc_top[0]
        acc_top5 += loss_batch_acc_top[1]
        
    loss /= (i_batch + 1)
    acc_top1 /= (i_batch + 1)
    acc_top5 /= (i_batch + 1)
#     writer.add_scalar(str("loss/" + mode), loss, epoch)
#     writer.add_scalar(str("top1/" + mode), acc_top1, epoch)
#     writer.add_scalar(str("top5/" + mode), acc_top5, epoch)
    
    return loss, acc_top1, acc_top5

In [None]:
from matplotlib import pyplot as plt
from IPython.display import clear_output

train_loss_history = []
train_acc_top1_history = []
train_acc_top5_history = []


val_loss_history = []
val_acc_top1_history = []
val_acc_top5_history = []

In [None]:
torch.save(model, '{}/checkpoint.pth.tar'.format(output_foloder))
scaler = GradScaler()
stop = 0
min_val_loss = 9999
for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)
    train_loss, train_acc_top1, train_acc_top5 = pass_epoch(train_loader, 'Train')  
    with torch.no_grad():
        val_loss, val_acc_top1, val_acc_top5 = pass_epoch(val_loader, 'Eval') 

    writer.add_scalars('loss', {'train':train_loss, 'val':val_loss}, epoch)
    writer.add_scalars('top1', {'train':train_acc_top1, 'val':val_acc_top1}, epoch)
    writer.add_scalars('top5', {'train':train_acc_top5, 'val':val_acc_top5}, epoch)

    
    train_loss_history.append(train_loss)
    train_acc_top1_history.append(train_acc_top1)
    train_acc_top5_history.append(train_acc_top5)


    val_loss_history.append(val_loss)
    val_acc_top1_history.append(val_acc_top1)
    val_acc_top5_history.append(val_acc_top5)
    
    update_loss_hist(train_loss_history, val_loss_history, 'Loss')
    update_loss_hist(train_acc_top5_history, val_acc_top5_history, 'Top5')
    update_loss_hist(train_acc_top1_history, val_acc_top1_history, 'Top1')
    model_scheduler.step(val_loss)
    if (val_loss <= min_val_loss):
        val_loss = min_val_loss
        save_checkpoint({
            'epoch': epoch + 1,
            'learning_rate': learning_rate,
            'loss': 'CrossEntropyLoss',
            'state_dict': model.state_dict(),
        }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(output_foloder, epoch + 1))
    else:
        stop += 1
        if (stop > 5):
            print('early stopping')
            break
torch.save(model, '{}/checkpoint.pth.tar'.format(output_foloder))
torch.cuda.empty_cache()

In [None]:
torch.save(model, '{}/checkpoint.pth.tar'.format(output_foloder))