In [None]:
import warnings
warnings.filterwarnings('ignore')

import gc
import timm
import wandb
import random
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader 
from torch.cuda.amp import autocast, GradScaler
from torchvision.datasets import ImageFolder

from sklearn.model_selection import train_test_split

from utils.utils import ImageLoader, TransformsCE


def seed_everything(seed):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic mode.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (CPU, the
    current CUDA device, and all CUDA devices). Then it turns off cuDNN
    autotuning so convolution algorithm selection does not vary between runs.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python / NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # torch generators (the CUDA calls are safe no-ops on CPU-only machines)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # trade cuDNN autotuning speed for reproducibility
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
SEED = 333
seed_everything(SEED)

IMG_SIZE = 224
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 10

RUN_NAME = 'Classification-Example'
WANDB_PRJ = 'public'

# Log every tunable hyperparameter so runs that differ only in
# image size or epoch count remain distinguishable in W&B.
WANDB_CONFIG = {
    'seed': SEED,
    'model': 'resnet50',
    'img_size': IMG_SIZE,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'epochs': EPOCHS,
}

In [None]:
dataset = ImageFolder("./datasets/cat-and-dog/training_set/training_set/")
# 80/20 split of ImageFolder's (path, class_index) samples; labels are returned
# separately but the samples already carry them.
trn_data, val_data, trn_label, val_label = train_test_split(
    dataset.imgs, dataset.targets, test_size=0.2, random_state=SEED
)

# One ImageLoader per phase; `phase` presumably selects train vs. eval
# augmentations inside TransformsCE/ImageLoader -- TODO confirm in utils.utils.
img_datasets = {
    phase: ImageLoader(dataset=samples, phase=phase, transform=TransformsCE(IMG_SIZE))
    for phase, samples in (('train', trn_data), ('valid', val_data))
}
trn_ds, val_ds = img_datasets['train'], img_datasets['valid']

# Shuffle and drop the ragged last batch only for training.
# Validation must see every held-out sample, so it keeps the final partial batch.
trn_dl = DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
dataloaders = {'train': trn_dl, 'valid': val_dl}

dataset_sizes = {phase: len(ds) for phase, ds in img_datasets.items()}

In [None]:
# Pretrained backbone with a fresh classifier head sized for this dataset's
# classes, instead of the 1000-way ImageNet head.
model = timm.create_model('resnet50', pretrained=True, num_classes=len(dataset.classes))
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
# StepLR multiplies the LR by gamma (default 0.1) every `step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
# Gradient scaling only matters for CUDA mixed precision; disable it explicitly on CPU.
scaler = GradScaler(enabled=torch.cuda.is_available())

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.amp autocast is only meaningful on a GPU.
use_amp = device == 'cuda'
print(f'model running on {device}')

wandb.init(name=RUN_NAME, project=WANDB_PRJ, config=WANDB_CONFIG, reinit=True)

model = model.to(device)


def run_phase(phase):
    """Run one pass over ``dataloaders[phase]`` and return ``(mean_loss, accuracy)``.

    The 'train' phase updates the weights through the AMP ``scaler``. Any other
    phase runs in eval mode with gradients disabled. Both metrics are averaged
    over the samples actually processed. That count is smaller than the dataset
    length when the loader uses ``drop_last=True``.

    Raises:
        RuntimeError: if the loader yields no batches (e.g. fewer samples than
            the batch size with ``drop_last=True``).
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    n_seen = 0

    for features, labels in dataloaders[phase]:
        features = features.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            with autocast(enabled=use_amp):
                logits = model(features)
                loss = criterion(logits, labels)

        if is_train:
            # The AMP docs recommend running backward/step outside the autocast region.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        batch_size = features.size(0)
        running_loss += loss.item() * batch_size
        running_corrects += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise RuntimeError(f'{phase} loader produced no batches')
    return running_loss / n_seen, running_corrects / n_seen


try:
    for e in range(EPOCHS):
        gc.collect()
        torch.cuda.empty_cache()
        running_log = {'epoch': e + 1}

        for phase in ['train', 'valid']:
            epoch_loss, epoch_acc = run_phase(phase)
            running_log.update({f'{phase}_loss': epoch_loss, f'{phase}_acc': epoch_acc})
            print(f'epoch {e+1} {phase} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

        # One scheduler step per epoch, after the optimizer has stepped. The old
        # `e >= 10` gate was unreachable with EPOCHS=10, so the schedule never ran.
        scheduler.step()
        wandb.log(running_log)
finally:
    # Close the W&B run even if training raises, so it is not left dangling.
    wandb.finish()