In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn

import random
import time
from datetime import datetime

import os

import pandas as pd

import timm
from timm.optim import create_optimizer_v2
import torch.nn.functional as F

from collections import OrderedDict

In [2]:
random.seed(42)
num_classes = 36
using_dataset = f"food-{num_classes}"
using_transformer = "CSWin"

models_transformers = {
    'ViT': 'vit_base_patch16_224_in21k',
    'CSWin': 'CSWin_96_24322_base_224',
    'BEiT': 'beit_base_patch16_224_in22k',
    'SWin': 'swin_base_patch4_window7_224_in22k',
    'DeiT': 'deit_tiny_distilled_patch16_224',
}
# with open(f"./{using_dataset}/meta/train.txt", 'r') as f_train: 
#     correct_images_filepaths = [f"./{using_dataset}/images/{line[:-1]}.jpg" for line in f_train.readlines()]
#     #correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]
#     random.shuffle(correct_images_filepaths)
#     train_images_paths = correct_images_filepaths

# print(len(train_images_paths))

# with open(f"./{using_dataset}/meta/test.txt", 'r') as f_test: 
#     test_images_paths = [f"./{using_dataset}/images/{line[:-1]}.jpg" for line in f_test.readlines()]
#     #test_images_paths = [i for i in test_images_paths if cv2.imread(i) is not None]
#     #random.shuffle(test_images_paths)
#     #test_images_paths = test_images_paths[:500]

# print(len(test_images_paths))
# train_images_paths[:20]

In [3]:
data_transform_train = transforms.Compose(
            [transforms.Resize([256,256]),
             transforms.RandomCrop([224,224]),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
             ])
data_transform_test = transforms.Compose(
            [transforms.Resize([224,224]),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
             ])

ds_train = datasets.ImageFolder(f"{using_dataset}/train", transform=data_transform_train)
ds_test = datasets.ImageFolder(f"{using_dataset}/test", transform=data_transform_test)


# label2id = ds_test.class_to_idx
# id2label = {identifier: label for label, identifier in label2id.items()}

# label2id

In [4]:
BATCH_SIZE = 8
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False)
dl_train

<torch.utils.data.dataloader.DataLoader at 0x7fcc43cbba00>

In [5]:
def get_accuracy(predictions, labels):
    return (predictions.argmax(dim=1) == labels).float().mean()

def get_loss(predictions, labels):
    predictions = predictions.reshape(-1, predictions.shape[-1])
    #labels = labels.unsqueeze(1).expand(-1, 1).reshape(-1)
    return F.cross_entropy(predictions, labels)
def load_checkpoint(model, checkpoint_path):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        state_dict_key = 'state_dict_ema'
        model_key = 'model'
        if state_dict_key and state_dict_key in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint[state_dict_key].items():
                # strip `module.` prefix
                name = k[7:] if k.startswith('module') else k
                if 'head' in k:
                    continue
                new_state_dict[name] = v
            state_dict = new_state_dict
    else:
        raise FileNotFoundError()
    model_dict = model.state_dict()
    pretrained_dict = state_dict
    loaded_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    load_dic = [k for k, v in pretrained_dict.items() if k in model_dict]
    miss_dic = [k for k, v in pretrained_dict.items() if not (k in model_dict)]
    unexpect_dic = [k for k, v in model_dict.items() if not (k in pretrained_dict)]
    model_dict.update(loaded_dict)
    model.load_state_dict(model_dict, strict=True)
    return model

In [6]:
model = timm.create_model(models_transformers[using_transformer], pretrained=False, num_classes=num_classes)
opt = create_optimizer_v2(model, lr=1e-3)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if using_transformer == 'CSWin':
    old_dict = model.state_dict()
    print(old_dict)
    model = load_checkpoint(model, 'cswin_base_224.pth')
    new_dict = model.state_dict()
    print("==============================================")
    print(new_dict)
#     checkpoint = torch.load('upernet_cswin_base.pth')
#     print(checkpoint)
#     model.load_state_dict(checkpoint['state_dict'])

# PATH = f"./results/{using_dataset}/SWin/20211029-071847/model.pt"
# checkpoint = torch.load(PATH)
# model.load_state_dict(checkpoint['model_state_dict'])
# opt.load_state_dict(checkpoint['optimizer_state_dict'])
# last_epoch = checkpoint['epoch']
# loss = checkpoint['loss']
# acc = checkpoint['acc']

model = model.to(device)
model.eval()

OrderedDict([('stage1_conv_embed.0.weight', tensor([[[[ 0.0355,  0.0141,  0.0019,  ..., -0.0129,  0.0284, -0.0221],
          [ 0.0755,  0.0030, -0.0633,  ...,  0.0164,  0.0684, -0.0251],
          [ 0.0798,  0.0758,  0.0484,  ..., -0.0489, -0.0040,  0.0131],
          ...,
          [ 0.0276,  0.0470, -0.0304,  ..., -0.0391, -0.0617, -0.0562],
          [ 0.0637, -0.0361, -0.0524,  ...,  0.0108,  0.0672,  0.0734],
          [-0.0706, -0.0397,  0.0802,  ...,  0.0632, -0.0525, -0.0009]],

         [[-0.0329, -0.0366, -0.0452,  ..., -0.0732, -0.0444, -0.0386],
          [ 0.0175,  0.0034,  0.0628,  ...,  0.0567, -0.0475, -0.0316],
          [ 0.0082, -0.0496, -0.0029,  ...,  0.0223,  0.0631, -0.0613],
          ...,
          [-0.0004, -0.0485,  0.0732,  ..., -0.0058, -0.0396,  0.0496],
          [-0.0773, -0.0089, -0.0324,  ...,  0.0449, -0.0795, -0.0180],
          [-0.0149,  0.0516,  0.0070,  ...,  0.0601, -0.0399, -0.0594]],

         [[ 0.0377, -0.0426,  0.0721,  ..., -0.0059,  0.06

OrderedDict([('stage1_conv_embed.0.weight', tensor([[[[-2.8043e-03, -7.2520e-03, -5.4653e-03,  ..., -2.6249e-02,
           -3.0216e-02, -1.6165e-02],
          [ 1.8042e-03,  7.1927e-04,  4.3801e-03,  ..., -1.9118e-02,
           -3.4291e-02, -2.8386e-02],
          [ 4.9256e-03, -3.1153e-03,  6.1453e-03,  ..., -2.8922e-02,
           -3.7641e-02, -4.4335e-02],
          ...,
          [ 2.4913e-02,  1.7818e-02,  3.5070e-02,  ...,  2.4015e-04,
           -9.2273e-03, -1.6795e-02],
          [ 2.5416e-02,  2.3716e-02,  3.9814e-02,  ...,  8.3865e-03,
           -1.5818e-03, -1.0439e-02],
          [ 1.9004e-02,  2.0107e-02,  3.6733e-02,  ...,  1.4813e-02,
           -5.1330e-04, -3.5696e-03]],

         [[ 5.5214e-03, -3.6639e-04, -1.8163e-03,  ..., -2.5167e-02,
           -3.0086e-02, -1.1933e-02],
          [ 1.1769e-02,  1.0062e-02,  1.3095e-02,  ..., -1.4434e-02,
           -3.1939e-02, -2.3743e-02],
          [ 1.6053e-02,  7.8624e-03,  1.3417e-02,  ..., -2.4053e-02,
           -3.

CSWinTransformer(
  (stage1_conv_embed): Sequential(
    (0): Conv2d(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
    (1): Rearrange('b c h w -> b (h w) c', h=56, w=56)
    (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (stage1): ModuleList(
    (0): CSWinBlock(
      (qkv): Linear(in_features=96, out_features=288, bias=True)
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (proj): Linear(in_features=96, out_features=96, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
      (attns): ModuleList(
        (0): LePEAttention(
          (get_v): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)
          (attn_drop): Dropout(p=0.0, inplace=False)
        )
        (1): LePEAttention(
          (get_v): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)
          (attn_drop): Dropout(p=0.0, inplace=False)
        )
      )
      (drop_path): Identity()
      (mlp): Mlp(
   

In [7]:
for i in model.state_dict():
    print(f"{i}: {model.state_dict()[i].size()}")
model.state_dict()

stage1_conv_embed.0.weight: torch.Size([96, 3, 7, 7])
stage1_conv_embed.0.bias: torch.Size([96])
stage1_conv_embed.2.weight: torch.Size([96])
stage1_conv_embed.2.bias: torch.Size([96])
stage1.0.qkv.weight: torch.Size([288, 96])
stage1.0.qkv.bias: torch.Size([288])
stage1.0.norm1.weight: torch.Size([96])
stage1.0.norm1.bias: torch.Size([96])
stage1.0.proj.weight: torch.Size([96, 96])
stage1.0.proj.bias: torch.Size([96])
stage1.0.attns.0.get_v.weight: torch.Size([48, 1, 3, 3])
stage1.0.attns.0.get_v.bias: torch.Size([48])
stage1.0.attns.1.get_v.weight: torch.Size([48, 1, 3, 3])
stage1.0.attns.1.get_v.bias: torch.Size([48])
stage1.0.mlp.fc1.weight: torch.Size([384, 96])
stage1.0.mlp.fc1.bias: torch.Size([384])
stage1.0.mlp.fc2.weight: torch.Size([96, 384])
stage1.0.mlp.fc2.bias: torch.Size([96])
stage1.0.norm2.weight: torch.Size([96])
stage1.0.norm2.bias: torch.Size([96])
stage1.1.qkv.weight: torch.Size([288, 96])
stage1.1.qkv.bias: torch.Size([288])
stage1.1.norm1.weight: torch.Size([96]

stage3.6.norm2.weight: torch.Size([384])
stage3.6.norm2.bias: torch.Size([384])
stage3.7.qkv.weight: torch.Size([1152, 384])
stage3.7.qkv.bias: torch.Size([1152])
stage3.7.norm1.weight: torch.Size([384])
stage3.7.norm1.bias: torch.Size([384])
stage3.7.proj.weight: torch.Size([384, 384])
stage3.7.proj.bias: torch.Size([384])
stage3.7.attns.0.get_v.weight: torch.Size([192, 1, 3, 3])
stage3.7.attns.0.get_v.bias: torch.Size([192])
stage3.7.attns.1.get_v.weight: torch.Size([192, 1, 3, 3])
stage3.7.attns.1.get_v.bias: torch.Size([192])
stage3.7.mlp.fc1.weight: torch.Size([1536, 384])
stage3.7.mlp.fc1.bias: torch.Size([1536])
stage3.7.mlp.fc2.weight: torch.Size([384, 1536])
stage3.7.mlp.fc2.bias: torch.Size([384])
stage3.7.norm2.weight: torch.Size([384])
stage3.7.norm2.bias: torch.Size([384])
stage3.8.qkv.weight: torch.Size([1152, 384])
stage3.8.qkv.bias: torch.Size([1152])
stage3.8.norm1.weight: torch.Size([384])
stage3.8.norm1.bias: torch.Size([384])
stage3.8.proj.weight: torch.Size([384, 3

stage3.18.mlp.fc2.bias: torch.Size([384])
stage3.18.norm2.weight: torch.Size([384])
stage3.18.norm2.bias: torch.Size([384])
stage3.19.qkv.weight: torch.Size([1152, 384])
stage3.19.qkv.bias: torch.Size([1152])
stage3.19.norm1.weight: torch.Size([384])
stage3.19.norm1.bias: torch.Size([384])
stage3.19.proj.weight: torch.Size([384, 384])
stage3.19.proj.bias: torch.Size([384])
stage3.19.attns.0.get_v.weight: torch.Size([192, 1, 3, 3])
stage3.19.attns.0.get_v.bias: torch.Size([192])
stage3.19.attns.1.get_v.weight: torch.Size([192, 1, 3, 3])
stage3.19.attns.1.get_v.bias: torch.Size([192])
stage3.19.mlp.fc1.weight: torch.Size([1536, 384])
stage3.19.mlp.fc1.bias: torch.Size([1536])
stage3.19.mlp.fc2.weight: torch.Size([384, 1536])
stage3.19.mlp.fc2.bias: torch.Size([384])
stage3.19.norm2.weight: torch.Size([384])
stage3.19.norm2.bias: torch.Size([384])
stage3.20.qkv.weight: torch.Size([1152, 384])
stage3.20.qkv.bias: torch.Size([1152])
stage3.20.norm1.weight: torch.Size([384])
stage3.20.norm1.

stage3.31.mlp.fc1.weight: torch.Size([1536, 384])
stage3.31.mlp.fc1.bias: torch.Size([1536])
stage3.31.mlp.fc2.weight: torch.Size([384, 1536])
stage3.31.mlp.fc2.bias: torch.Size([384])
stage3.31.norm2.weight: torch.Size([384])
stage3.31.norm2.bias: torch.Size([384])
merge3.conv.weight: torch.Size([768, 384, 3, 3])
merge3.conv.bias: torch.Size([768])
merge3.norm.weight: torch.Size([768])
merge3.norm.bias: torch.Size([768])
stage4.0.qkv.weight: torch.Size([2304, 768])
stage4.0.qkv.bias: torch.Size([2304])
stage4.0.norm1.weight: torch.Size([768])
stage4.0.norm1.bias: torch.Size([768])
stage4.0.proj.weight: torch.Size([768, 768])
stage4.0.proj.bias: torch.Size([768])
stage4.0.attns.0.get_v.weight: torch.Size([768, 1, 3, 3])
stage4.0.attns.0.get_v.bias: torch.Size([768])
stage4.0.mlp.fc1.weight: torch.Size([3072, 768])
stage4.0.mlp.fc1.bias: torch.Size([3072])
stage4.0.mlp.fc2.weight: torch.Size([768, 3072])
stage4.0.mlp.fc2.bias: torch.Size([768])
stage4.0.norm2.weight: torch.Size([768])
s

OrderedDict([('stage1_conv_embed.0.weight',
              tensor([[[[-2.8043e-03, -7.2520e-03, -5.4653e-03,  ..., -2.6249e-02,
                         -3.0216e-02, -1.6165e-02],
                        [ 1.8042e-03,  7.1927e-04,  4.3801e-03,  ..., -1.9118e-02,
                         -3.4291e-02, -2.8386e-02],
                        [ 4.9256e-03, -3.1153e-03,  6.1453e-03,  ..., -2.8922e-02,
                         -3.7641e-02, -4.4335e-02],
                        ...,
                        [ 2.4913e-02,  1.7818e-02,  3.5070e-02,  ...,  2.4015e-04,
                         -9.2273e-03, -1.6795e-02],
                        [ 2.5416e-02,  2.3716e-02,  3.9814e-02,  ...,  8.3865e-03,
                         -1.5818e-03, -1.0439e-02],
                        [ 1.9004e-02,  2.0107e-02,  3.6733e-02,  ...,  1.4813e-02,
                         -5.1330e-04, -3.5696e-03]],
              
                       [[ 5.5214e-03, -3.6639e-04, -1.8163e-03,  ..., -2.5167e-02,
                  

In [8]:
df_val = pd.DataFrame()
val_labels = []
val_pred = []
losses, accs = [], []
with torch.no_grad():
    for images, labels in dl_test:
        images = images.to(device)
        labels = labels.to(device)
        pred = model(images)
        loss = get_loss(pred, labels)
        acc = get_accuracy(pred, labels)
        accs.append(acc * images.shape[0])            
        losses.append(loss * images.shape[0])
        val_labels.append(labels.cpu().detach().numpy())
        val_pred.append(pred.cpu().detach().numpy())

df_val[f"{using_transformer}_epoch_0_pred"] = val_pred
df_val[f"{using_transformer}_epoch_0_labels"] = val_labels
df_val.to_csv(f"results/{using_transformer}_epoch_0_val.csv")
loss = torch.stack(losses).sum() / len(dl_test.dataset)
acc = torch.stack(accs).sum() / len(dl_test.dataset)

print(f'Epoch: 0    Loss: {loss.item():.3f}    Accuracy: {acc:.3f}')

output_dir = f"./checkpoints/{using_dataset}/{using_transformer}/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
os.system(f"mkdir {output_dir}")

torch.save({
        'epoch': 0,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss.item(),
        'acc': acc
    }, f"{output_dir}/model_0.pt")

Epoch: 0    Loss: 3.641    Accuracy: 0.029


In [None]:
EPOCHS = 20
df_val = pd.DataFrame()
df_train = pd.DataFrame()

previous_loss = 100

for epoch in range(EPOCHS):
    losses_test, accs_test = [], []
    model.train()
    train_labels = []
    train_pred = []
    for images, labels in dl_train:
        opt.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        pred = model(images)
        if using_transformer == 'DeiT':
            pred = pred[0]
        loss = get_loss(pred, labels)
        acc = get_accuracy(pred, labels)
        train_pred.append(pred.cpu().detach().numpy())
        train_labels.append(labels.cpu().detach().numpy())
        loss.backward()
        opt.step()
    df_train[f"epoch_{epoch+1}_pred"] = train_pred
    df_train[f"epoch_{epoch+1}_labels"] = train_labels
    df_train.to_csv(f"results/{using_transformer}_epoch_{epoch+1}_train.csv")
    
    model.eval()
    val_labels = []
    val_pred = []
    with torch.no_grad():
        for images, labels in dl_test:
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images)
            loss = get_loss(pred, labels)
            acc = get_accuracy(pred, labels)
            accs_test.append(acc * images.shape[0])            
            losses_test.append(loss * images.shape[0])
            val_labels.append(labels.cpu().detach().numpy())
            val_pred.append(pred.cpu().detach().numpy())
            
    df_val[f"{using_transformer}_epoch_{epoch}_pred"] = val_pred
    df_val[f"{using_transformer}_epoch_{epoch}_labels"] = val_labels
    df_val.to_csv(f"results/{using_transformer}_epoch_{epoch}_val.csv")
    loss = torch.stack(losses_test).sum() / len(dl_test.dataset)
    acc = torch.stack(accs_test).sum() / len(dl_test.dataset)
    
    try:
        os.system(f"rm results/{using_transformer}_epoch_{epoch-2}_train.csv")
        os.system(f"rm results/{using_transformer}_epoch_{epoch-2}_val.csv")
    except:
        continue
    
    print(f'Epoch: {epoch+1:>2}    Loss: {loss.item():.3f}    Accuracy: {acc:.3f}')
    if previous_loss >= loss.item():
        previous_loss = loss.item()
    elif previous_loss < loss.item()-0.5:
        break
        
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss.item(),
        'acc': acc
    }, f"{output_dir}/model_{epoch}.pt")

Epoch:  1    Loss: 0.450    Accuracy: 0.862
Epoch:  2    Loss: 0.379    Accuracy: 0.885
Epoch:  3    Loss: 0.351    Accuracy: 0.896
Epoch:  4    Loss: 0.444    Accuracy: 0.884
Epoch:  5    Loss: 0.399    Accuracy: 0.897
Epoch:  6    Loss: 0.471    Accuracy: 0.889
Epoch:  7    Loss: 0.399    Accuracy: 0.909
Epoch:  8    Loss: 0.424    Accuracy: 0.905
