In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms.v2 as tfs
import torch.nn.functional as F
from torchinfo import summary

from torchvision.models import swin_t, Swin_T_Weights

import pandas as pd
from PIL import Image
from tqdm import tqdm

import os
import warnings
from IPython.display import clear_output

warnings.filterwarnings('ignore')

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Paths ---
# Root folder of the dataset; the dataset class looks for per-split sub-folders inside it.
PATH_TO_DATASET = './KeyPointsDataset'

# --- Images ---
# Spatial size (in pixels) every image is resized to before entering the network.
IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224

# --- Training ---
NUM_CLASSES = 20    # number of values the model head outputs
BATCH_SIZE = 64     # samples per DataLoader batch
NUM_EPOCH = 20      # full passes over the training set
BASE_LR = 3e-4      # learning rate handed to the optimizer

### Dataset

In [3]:
class KeyPointsDataset(data.Dataset):
    """Image / keypoint-regression dataset backed by one CSV file per split.

    Files are read from the following layout::

        <path>/<type>/<type>_frames_keypoints.csv
        <path>/<type>/images/<img_name>

    The CSV must contain an ``img_name`` column; every other column is taken,
    in column order, as one target value for that image.
    """

    def __init__(self, path, type, transform=None):
        """Read the split's CSV and cache file names and targets in memory.

        Args:
            path: Root directory of the dataset.
            type: Split name (e.g. ``'train'`` or ``'test'``); used both as the
                sub-folder name and as the CSV file-name prefix.
            transform: Optional callable applied to the loaded PIL image. When
                given, targets are rescaled to the
                ``IMAGE_WIDTH`` x ``IMAGE_HEIGHT`` frame (see ``__getitem__``).
        """
        super().__init__()

        self.path = path
        self.type = type
        self.transform = transform
        self.csv = pd.read_csv(os.path.join(self.path, type, type + '_frames_keypoints.csv'))
        self.len = self.csv.shape[0]

        # Extract both columns sets in one vectorised pass instead of a
        # per-row ``iloc`` loop, which is very slow on large CSV files.
        self.files = self.csv['img_name'].to_list()
        self.targets = self.csv.drop(columns='img_name').to_numpy().tolist()

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The image is always converted to 3-channel RGB. With a transform the
        image is passed through it and the targets are rescaled to the resized
        frame; without a transform the PIL image and the raw CSV targets are
        returned unchanged.

        Returns:
            Tuple of the (possibly transformed) image and a 1-D ``float32``
            tensor of target values.
        """
        image = self.files[index]
        targets = self.targets[index]

        # ``convert`` returns an independent in-memory copy, so the file handle
        # can be closed right away. Forcing RGB keeps grayscale / RGBA files
        # compatible with 3-channel normalisation and the pretrained backbone.
        with Image.open(os.path.join(self.path, self.type, 'images', image)) as img_file:
            img = img_file.convert('RGB')
        width, height = img.size

        if self.transform:
            # The first 10 values are scaled by the width ratio and the rest by
            # the height ratio.
            # NOTE(review): this assumes the CSV stores 10 x-coordinates followed
            # by the y-coordinates -- confirm against the dataset's column order.
            target = [t * (IMAGE_WIDTH / width) for t in targets[:10]]
            target.extend(t * (IMAGE_HEIGHT / height) for t in targets[10:])
            img = self.transform(img)
        else:
            # No resize happens, so the targets stay in original pixel
            # coordinates (previously an empty tensor was returned here).
            target = list(targets)

        return img, torch.tensor(target, dtype=torch.float32)

    def __len__(self):
        """Return the number of rows in the split's CSV file."""
        return self.len

In [4]:
# Preprocessing applied to every image: tensor conversion, bicubic resize to the
# network input size, float32 cast and per-channel normalisation.
# NOTE(review): recent torchvision releases deprecate v2.ToTensor in favour of
# ToImage() + ToDtype(torch.float32, scale=True) -- consider migrating once the
# installed torchvision version is confirmed to provide them.
transforms = tfs.Compose([
    tfs.ToTensor(),
    # Resize takes its size as (height, width); the previous (width, height)
    # order only worked because both constants are currently equal.
    tfs.Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=tfs.InterpolationMode.BICUBIC),
    tfs.ToDtype(torch.float32),
    tfs.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

In [5]:
# Build one dataset per split; both share the same preprocessing pipeline.
train_dataset = KeyPointsDataset(
    path=PATH_TO_DATASET,
    type='train',
    transform=transforms,
)
test_dataset = KeyPointsDataset(
    path=PATH_TO_DATASET,
    type='test',
    transform=transforms,
)

### Model adaptation

In [6]:
# Load the Swin-T backbone with its ImageNet-1k pretrained weights.
model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)

# Freeze every pretrained parameter first ...
for param in model.parameters():
    param.requires_grad = False

# ... then swap in a new output layer. Its input width is read from the existing
# head instead of being hard-coded, and a freshly created layer is trainable by
# default, so no explicit unfreezing of the head is required.
model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

# Also fine-tune the last two entries of the backbone's feature stack.
for param in model.features[-2:].parameters():
    param.requires_grad = True

model.to(device)
# Wipe whatever this cell printed so far (presumably weight-download progress).
clear_output()

### Finetuning

In [10]:
# Batch both splits; only the training data is reshuffled on every pass.
train_dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
def train_model(model, train_data, test_data, save=None):
    """Fine-tune ``model`` with an MSE objective and evaluate it after every epoch.

    Only parameters with ``requires_grad=True`` are optimised (AdamW, learning
    rate ``BASE_LR``). Training runs for ``NUM_EPOCH`` epochs on the module-level
    ``device``. The test loss of the previous epoch is shown in the progress-bar
    description, and the final value is printed once training ends.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_data: Iterable of ``(inputs, targets)`` training batches.
        test_data: Iterable of ``(inputs, targets)`` evaluation batches.
        save: Optional file stem; when given, the state dict is written to
            ``models/<save>.tar``.
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable_params, BASE_LR)
    loss_func = nn.MSELoss()

    test_loss = 0.0
    for _e in range(NUM_EPOCH):
        # ----- training -----
        # Must be re-enabled every epoch: the evaluation pass below switches the
        # model to eval mode, which would otherwise persist into later epochs.
        model.train()
        for x_batch, y_batch in tqdm(train_data, desc=f'Epoch {_e + 1}/{NUM_EPOCH}, test loss: {test_loss:.4f}', leave=False):
            x_batch, y_batch = x_batch.to(device).float(), y_batch.to(device).float()

            # gradient descent
            opt.zero_grad()
            output = model(x_batch)
            loss = loss_func(output, y_batch)

            loss.backward()
            opt.step()

        # ----- evaluation -----
        model.eval()
        loss_sum, num_samples = 0.0, 0
        with torch.no_grad():
            for x_test, y_test in test_data:
                x_test, y_test = x_test.to(device).float(), y_test.to(device).float()

                # predicted keypoint values
                output = model(x_test)

                # Weight each batch mean by its size so a smaller last batch
                # does not skew the average; .item() keeps a plain Python float.
                batch_size = x_test.size(0)
                loss_sum += F.mse_loss(output, y_test).item() * batch_size
                num_samples += batch_size

        # Guard against an empty test set.
        test_loss = loss_sum / max(num_samples, 1)

    # The progress bars are cleared (leave=False), so report the final value explicitly.
    print(f'Final test loss: {test_loss:.4f}')

    # ----- saving -----
    if save:
        save_path = 'models/' + save + '.tar'
        # Create the target folder up front so a finished run is not lost to a missing directory.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(model.state_dict(), save_path)

In [None]:
# Run the fine-tuning loop; the trained weights are saved as models/ft_swin_t.tar.
train_model(
    model,
    train_dataloader,
    test_dataloader,
    save='ft_swin_t',
)