In [None]:
from __future__ import print_function  # future imports must precede every other statement

import glob
import os
import random
from itertools import chain

# import cv2
#import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# import torch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from linformer import Linformer
from PIL import Image
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from torch.nn.parallel import DataParallel
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.ops import focal_loss
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT


# CONSTS (might need changing)
data_pth = '/l/users/aleksandr.matsun/stl_balanced/home/aleksandr.matsun/stl_datasets/stl_ds_balanced'
model_name = 'vit'  # 'vit' or 'resnet50'
labels_train_path = 'labels_balanced_train.csv'
labels_valid_path = 'labels_balanced_valid.csv'
batch_size = 256
epochs = 5
lr = 3e-5
gamma = 0.7  # multiplicative learning-rate decay used by the StepLR scheduler
# Fall back to CPU so the script can still start on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

mask = '111'  # one '0'/'1' flag per entry of the tools_present label column
checkpoint_pth = 'checkpoints_balanced_' + mask
load_weights = False
# Only read when load_weights is True. Expected to be a state_dict saved by this
# script from the DataParallel-wrapped model (keys prefixed with 'module.').
weights_pth = 'path/to/checkpoint.pth'
wandb_project_name = 'stl_' + mask


# WANDB
# Uncomment to enable Weights & Biases logging (config mirrors the constants above).
# wandb.init(project=wandb_project_name,
#            config={'learning_rate': lr,
#                    'epochs': epochs,
#                    'batch_size': batch_size})

# DATASET

class STLDataset(Dataset):
    """Multi-label image dataset of per-frame tool/instrument presence.

    Reads a label CSV (its first column becomes the index) and serves each row as
    ``(image, target)``. ``target`` is a float32 tensor with one entry per
    instrument column (``df.columns[2:16]``): 1.0 if that instrument is among the
    tools kept by ``mask``, else 0.0.

    Args:
        data_dir: root directory; each image is read from
            ``<data_dir>/<clip_name>/<frame>``.
        lbl_path: path to the label CSV. Its ``tools_present`` column is expected
            to hold the string form of a list of quoted names, e.g.
            ``"['a', 'b', 'c']"``.
        mask: string of '0'/'1' flags, one per ``tools_present`` slot; only slots
            flagged '1' count as present.
        transforms: optional callable applied to the RGB PIL image.
    """

    def __init__(self, data_dir, lbl_path, mask='111', transforms=None):
        self.mask = mask
        self.data_dir = data_dir
        self.transforms = transforms
        self.df = pd.read_csv(lbl_path, index_col=0)
        # Columns 2..15 (after the index) are the instrument label columns.
        self.instruments = list(self.df.columns[2:16])

        # "['a', 'b', 'c']" -> ['a', 'b', 'c']
        self.df['tools_present'] = self.df['tools_present'].apply(
            lambda raw: raw.strip("[']").split("', '"))
        # Keep only the slots enabled by the mask; zip stops at the shorter of the
        # two, so rows with fewer slots than mask flags no longer raise IndexError.
        self.df['tools_present'] = self.df['tools_present'].apply(
            lambda tools: [tool for tool, flag in zip(tools, self.mask) if flag == '1'])
        # Rebuild every instrument column from the masked tool lists.
        masked_tools = self.df['tools_present']
        for name in self.instruments:
            self.df[name] = [float(name in tools) for tools in masked_tools]

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_name = row['clip_name'] + '/' + row['frame']
        # Decode to 3-channel RGB (the model is built with channels=3) and close the
        # file handle once the pixels are loaded; convert() returns a loaded copy.
        with Image.open(self.data_dir + '/' + img_name) as img:
            img_data = img.convert('RGB')
        if self.transforms:
            img_data = self.transforms(img_data)

        # float32 to match the model's logits in the loss; read from the row already
        # fetched instead of re-indexing the whole DataFrame.
        target = torch.from_numpy(row[self.instruments].to_numpy(dtype=np.float32))
        return img_data, target

def _shared_preprocessing():
    """Return fresh instances of the steps both pipelines share.

    Converts the PIL image to a tensor, keeps the central 615x900 region,
    resizes it to 256x256 and normalizes it with the ImageNet channel statistics.
    """
    return [
        transforms.ToTensor(),
        transforms.CenterCrop((615, 900)),
        transforms.Resize((256, 256)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]


# Training pipeline: random flips and occasional color jitter on the PIL image,
# then the shared preprocessing, then random erasing on the normalized tensor.
# NOTE(review): ColorJitter() with no arguments uses brightness=contrast=
# saturation=hue=0, so it leaves the image unchanged; pass explicit ranges if
# color augmentation is actually wanted.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomApply(torch.nn.ModuleList([transforms.ColorJitter()]), p=0.25),
]
transies = transforms.Compose(
    _train_augmentations
    + _shared_preprocessing()
    + [transforms.RandomErasing(p=0.2, value='random')]
)

# Validation pipeline: deterministic preprocessing only.
transies_v = transforms.Compose(_shared_preprocessing())

# Training set with augmentation.
ds = STLDataset(data_dir=data_pth,
                lbl_path=labels_train_path,
                mask=mask,
                transforms=transies)

# Debug run: train on the first `train_subset_size` frames only.
# Set to None to train on the whole training set.
train_subset_size = 1000
if train_subset_size is None:
    smol_ds = ds
else:
    # Clamp so a label file with fewer rows does not produce out-of-range indices
    # (Subset would otherwise only fail later, inside a DataLoader worker).
    smol_ds = Subset(ds, np.arange(min(train_subset_size, len(ds))))

# Validation set with deterministic preprocessing.
ds_v = STLDataset(data_dir=data_pth,
                  lbl_path=labels_valid_path,
                  mask=mask,
                  transforms=transies_v)

dl = DataLoader(smol_ds, batch_size=batch_size, shuffle=True, num_workers=4)
# Not shuffled: validation predictions are compared row-by-row with ds_v.df.
dl_v = DataLoader(ds_v, batch_size=batch_size, shuffle=False, num_workers=4)

# MODEL
# ViT splits each image_size x image_size input into (image_size // patch_size) ** 2
# non-overlapping patches and prepends one cls token. Linformer's key/value
# projections are sized for exactly seq_len tokens, so the two must agree:
# 256px / 32px -> 8 x 8 = 64 patches + 1 cls token = 65.
image_size = 256  # must match the Resize((256, 256)) in the transforms
patch_size = 32
num_classes = 14  # one logit per instrument label column
efficient_transformer = Linformer(
    dim=128,
    seq_len=(image_size // patch_size) ** 2 + 1,
    depth=12,
    heads=8,
    k=64,
)
print('=' * 10, 'preparing the model', '=' * 10)
if model_name == 'resnet50':
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    # Replace the ImageNet classification head with a num_classes-way one.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
elif model_name == 'vit':
    model = ViT(
        dim=128,
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        transformer=efficient_transformer,
        channels=3,
    )
else:
    raise ValueError(f"unknown model_name {model_name!r}; expected 'resnet50' or 'vit'")

# Move to the primary device before wrapping (DataParallel expects the module's
# parameters on its first device), and wrap exactly once so checkpoint keys
# carry a single 'module.' prefix.
model = DataParallel(module=model.to(device))

# loading weights
if load_weights:
    # map_location lets a checkpoint saved on another device load onto `device`.
    state_dict = torch.load(weights_pth, map_location=device)
    # Checkpoints saved by this script come from the DataParallel wrapper and are
    # 'module.'-prefixed; unprefixed state_dicts are loaded into the inner model.
    if all(key.startswith('module.') for key in state_dict):
        model.load_state_dict(state_dict)
    else:
        model.module.load_state_dict(state_dict)

# Multi-label objective: independent sigmoid focal loss per class.
criterion = focal_loss.sigmoid_focal_loss
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiplies the learning rate by `gamma` every time scheduler.step() is called.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
# TRAIN + VALID CYCLE
# wandb.log raises unless wandb.init() has run (it is commented out in the WANDB
# section), so metrics are only sent when a run is active.
log_to_wandb = wandb.run is not None
# torch.save does not create missing directories.
os.makedirs(checkpoint_pth, exist_ok=True)

for epoch in range(epochs):
    print('epoch', epoch, 'started')

    # TRAIN
    model.train()
    for i, (data_, target_) in enumerate(dl):
        data_ = data_.to(device)
        # Cast targets to float32 so they match the logits inside the focal loss
        # (label tensors may arrive as float64).
        target_ = target_.to(device=device, dtype=torch.float32)
        optimizer.zero_grad()
        outputs = model(data_)
        loss = criterion(outputs, target_, reduction='mean')
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            loss_value = loss.item()
            print('loss:', loss_value)
            if log_to_wandb:
                wandb.log({'loss': loss_value})
    # StepLR(step_size=1) is meant to decay the learning rate once per epoch,
    # which only happens when it is stepped.
    scheduler.step()

    # VALID: no autograd graph is needed, which saves a lot of GPU memory.
    model.eval()
    batch_preds = []
    n_seen = 0
    with torch.no_grad():
        for data_, _ in dl_v:
            probs = torch.sigmoid(model(data_.to(device))).cpu().numpy()
            batch_preds.append(probs)
            n_seen += probs.shape[0]
            if n_seen % 100 == 0:
                print('eval iteration', n_seen)
    # Concatenate once rather than np.append per batch (which copies every time).
    preds = np.concatenate(batch_preds, axis=0) if batch_preds else np.empty((0, 14), float)
    # dl_v is not shuffled, so preds rows line up with ds_v.df rows.
    target = ds_v.df[ds_v.instruments].values
    f1 = f1_score(target, preds > 0.5, average='macro')
    try:
        auc = roc_auc_score(target, preds, average='macro')
    except ValueError as err:
        # e.g. an instrument with a single label value in the validation set;
        # an undefined metric should not cost this epoch's checkpoint.
        print('val AUC undefined:', err)
        auc = float('nan')
    print(f1, auc)
    if log_to_wandb:
        wandb.log({'val_f1_score': f1})
        wandb.log({'val_auc': auc})

    print('epoch', epoch, 'done')
    torch.save(model.state_dict(),
               os.path.join(checkpoint_pth, model_name + '_epoch' + str(epoch) + '_' + mask + '.pth'))
