In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn

import random
import time
from datetime import datetime

import os

import pandas as pd

import timm
from timm.optim import create_optimizer_v2
import torch.nn.functional as F

In [2]:
random.seed(42)
num_classes = 36
using_dataset = f"food-{num_classes}"
# with open(f"./{using_dataset}/meta/train.txt", 'r') as f_train: 
#     correct_images_filepaths = [f"./{using_dataset}/images/{line[:-1]}.jpg" for line in f_train.readlines()]
#     #correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]
#     random.shuffle(correct_images_filepaths)
#     train_images_paths = correct_images_filepaths

# print(len(train_images_paths))

# with open(f"./{using_dataset}/meta/test.txt", 'r') as f_test: 
#     test_images_paths = [f"./{using_dataset}/images/{line[:-1]}.jpg" for line in f_test.readlines()]
#     #test_images_paths = [i for i in test_images_paths if cv2.imread(i) is not None]
#     #random.shuffle(test_images_paths)
#     #test_images_paths = test_images_paths[:500]

# print(len(test_images_paths))
# train_images_paths[:20]

In [3]:
data_transform_train = transforms.Compose(
            [transforms.Resize([256,256]),
             transforms.RandomCrop([224,224]),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
             ])
data_transform_test = transforms.Compose(
            [transforms.Resize([224,224]),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
             ])

ds_train = datasets.ImageFolder(f"{using_dataset}/train", transform=data_transform_train)
ds_test = datasets.ImageFolder(f"{using_dataset}/test", transform=data_transform_test)


# label2id = ds_test.class_to_idx
# id2label = {identifier: label for label, identifier in label2id.items()}

# label2id

In [4]:
BATCH_SIZE = 8
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False)
dl_train

<torch.utils.data.dataloader.DataLoader at 0x7ff917db6160>

In [5]:
def get_accuracy(predictions, labels):
    return (predictions.argmax(dim=1) == labels).float().mean()

def get_loss(predictions, labels):
    predictions = predictions.reshape(-1, predictions.shape[-1])
    #labels = labels.unsqueeze(1).expand(-1, 1).reshape(-1)
    return F.cross_entropy(predictions, labels)

In [6]:
model = timm.create_model('swin_base_patch4_window7_224_in22k', pretrained=True, num_classes=num_classes)
opt = create_optimizer_v2(model, learning_rate=1e-5)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)

PATH = f"./checkpoints/{using_dataset}/SWin/20211029-071847/model.pt"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
opt.load_state_dict(checkpoint['optimizer_state_dict'])
last_epoch = checkpoint['epoch']
loss = checkpoint['loss']
acc = checkpoint['acc']


model.eval()

cuda


SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (0): BasicLayer(
      dim=128, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
     

In [7]:
df_val = pd.DataFrame()
val_labels = []
val_pred = []
losses, accs = [], []
with torch.no_grad():
    for images, labels in dl_test:
        images = images.to(device)
        labels = labels.to(device)
        pred = model(images)
        loss = get_loss(pred, labels)
        acc = get_accuracy(pred, labels)
        accs.append(acc * images.shape[0])            
        losses.append(loss * images.shape[0])
        val_labels.append(labels)
        val_pred.append(pred)

df_val[f"SWin_epoch_2_2_pred"] = val_pred
df_val[f"SWin_epoch_2_2_labels"] = val_labels
df_val.to_csv(f"SWin_epoch_2_2_val.csv")
loss = torch.stack(losses).sum() / len(dl_test.dataset)
acc = torch.stack(accs).sum() / len(dl_test.dataset)

print(f'Epoch: {0+1:>2}    Loss: {loss.item():.3f}    Accuracy: {acc:.3f}')

Epoch:  1    Loss: 0.432    Accuracy: 0.896


In [None]:
EPOCHS = 20
df_val = pd.DataFrame()
df_train = pd.DataFrame()
output_dir = f"./checkpoints/{using_dataset}/SWin/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
os.system(f"mkdir {output_dir}")

for epoch in range(2, EPOCHS):
    losses_test, accs_test = [], []
    model.train()
    train_labels = []
    train_pred = []
    for images, labels in dl_train:
        opt.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        pred = model(images)
        loss = get_loss(pred, labels)
        acc = get_accuracy(pred, labels)
        train_pred.append(pred)
        train_labels.append(labels)
        loss.backward()
        opt.step()
    df_train[f"epoch_{epoch}_pred"] = train_pred
    df_train[f"epoch_{epoch}_labels"] = train_labels
    df_train.to_csv(f"SWin_epoch_{epoch}_train.csv")
    
    model.eval()
    val_labels = []
    val_pred = []
    with torch.no_grad():
        for images, labels in dl_test:
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images)
            loss = get_loss(pred, labels)
            acc = get_accuracy(pred, labels)
            accs_test.append(acc * images.shape[0])            
            losses_test.append(loss * images.shape[0])
            val_labels.append(labels)
            val_pred.append(pred)
            
    df_val[f"SWin_epoch_{epoch}_pred"] = val_pred
    df_val[f"SWin_epoch_{epoch}_labels"] = val_labels
    df_val.to_csv(f"SWin_epoch_{epoch}_val.csv")
    loss = torch.stack(losses_test).sum() / len(dl_test.dataset)
    acc = torch.stack(accs_test).sum() / len(dl_test.dataset)

    print(f'Epoch: {epoch+1:>2}    Loss: {loss.item():.3f}    Accuracy: {acc:.3f}')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss.item(),
        'acc': acc
    }, f"{output_dir}/model.pt")

Epoch:  3    Loss: 0.351    Accuracy: 0.911
Epoch:  4    Loss: 0.338    Accuracy: 0.908
