In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import wandb
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

sys.path.append("../")
import utils

%matplotlib inline

In [2]:
# --- Hyperparameters ---
batch_size = 32
lr = 1e-4
epochs = 5

# --- Model ---
# ConvNeXt-Tiny loaded with torchvision's default pretrained weights; the last
# layer of its classifier is replaced by a fresh 2-output linear head.
# (To try ConvNeXt-Large instead, build torchvision.models.convnext_large and
# give the replacement head 1536 input features.)
model = torchvision.models.convnext_tiny(weights="DEFAULT", progress=True)
model.classifier[2] = nn.Linear(in_features=768, out_features=2)

# --- Loss / optimisation ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
# The scheduler is stepped with a value that should decrease ('min' mode);
# patience=1 and verbose=True are passed straight through to PyTorch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, verbose=True)

# --- Device ---
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Per-sample preprocessing: random horizontal flip (augmentation), resize to
# 224x224, conversion to a tensor, then per-channel normalisation with the
# standard ImageNet mean / std.
tf = torchvision.transforms.Compose([transforms.RandomHorizontalFlip(),
                                     transforms.Resize((224,224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

# Seed for the train/validation partition so the split is identical after a
# kernel restart (otherwise validation numbers are not comparable across runs).
SPLIT_SEED = 42

full_dataset = utils.train_dataset("../dataset/train.csv", tf=tf)

# 80 / 20 split; `test_size` is the size of the validation subset.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
assert train_size + test_size == len(full_dataset)

# NOTE(review): both subsets wrap the same `full_dataset`, so the validation
# samples presumably go through RandomHorizontalFlip as well - confirm how
# utils.train_dataset applies `tf` and consider a flip-free eval transform.
train_dataset, validation_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, num_workers=4)

In [4]:
def train(epoch):
    """Run one training epoch over `train_dataloader` and log metrics to wandb.

    Uses the notebook-level `model`, `optimizer`, `device` and
    `train_dataloader`. The loss of every optimisation step is logged as
    "step loss"; the epoch-level averages are logged once at the end.

    Args:
        epoch: Epoch number, used only for the progress messages.

    Returns:
        (accuracy, loss): fraction of correctly classified training samples
        and mean per-sample cross-entropy over the epoch. Both are NaN if the
        dataloader yields no samples.
    """
    model.train()
    print(f"Running Epoch {epoch}")

    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    for data, target in tqdm(train_dataloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data.float())
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Count against the real batch size: the final batch is usually
        # smaller than `batch_size`, so dividing by `batch_size` would
        # under-report accuracy.
        num_in_batch = target.size(0)
        step_loss = loss.item()
        total_correct += (output.argmax(dim=1) == target).sum().item()
        # `loss` is a batch mean; weight it by the batch size so the epoch
        # figure is a true per-sample average.
        total_loss += step_loss * num_in_batch
        total_samples += num_in_batch

        wandb.log({"step loss" : step_loss})

    if total_samples == 0:
        accuracy = loss = float("nan")
    else:
        accuracy = total_correct / total_samples
        loss = total_loss / total_samples

    wandb.log({"Avg Epoch Train Loss" : loss, "Epoch Train Accuracy" : accuracy})
    print(f"Epoch {epoch} : Avg Train accuracy : {accuracy}, Avg Loss : {loss}")
    return accuracy, loss

def validation():
    """Evaluate `model` on `validation_dataloader` and log metrics to wandb.

    Uses the notebook-level `model`, `device` and `validation_dataloader`.
    Runs in eval mode with gradient tracking disabled.

    Returns:
        (accuracy, validation_loss): fraction of correctly classified
        validation samples and mean per-sample cross-entropy, both plain
        Python floats. Both are NaN if the dataloader yields no samples.
    """
    model.eval()
    summed_loss = 0.0
    correct = 0
    num_samples = 0

    # No gradients are needed for evaluation; skipping them saves memory/time.
    with torch.no_grad():
        for data, target in validation_dataloader:
            data, target = data.to(device), target.to(device)
            # Same float cast as in train() so both paths feed the model alike.
            output = model(data.float())
            # Sum (not mean) per batch so the final division gives a per-sample average.
            summed_loss += F.cross_entropy(output, target, reduction="sum").item()
            # Predicted class = index of the largest logit.
            correct += (output.argmax(dim=1) == target).sum().item()
            num_samples += target.size(0)

    if num_samples == 0:
        validation_loss = accuracy = float("nan")
    else:
        validation_loss = summed_loss / num_samples
        accuracy = correct / num_samples

    print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        validation_loss, correct, num_samples,
        100. * accuracy))
    wandb.log({"Validation Loss" : validation_loss, "Validation Accuracy" : accuracy})
    return accuracy, validation_loss

In [None]:
# Start a wandb run and move the model to the selected device before training.
wandb.init()
model.to(device)

# Per-epoch metric histories. Created here so the cell works on a fresh
# kernel and starts from empty lists every time it is re-run.
train_accuracy, train_loss = [], []
validation_accuracy, validation_loss = [], []

# Make sure the checkpoint directory exists before the first save.
os.makedirs("../model", exist_ok=True)

for epoch in range(1, epochs + 1):
    acc, loss = train(epoch)
    train_accuracy.append(acc)
    train_loss.append(loss)

    val_acc, val_loss = validation()
    validation_accuracy.append(val_acc)
    validation_loss.append(val_loss)

    # The plateau scheduler watches the validation loss.
    scheduler.step(val_loss)
    # One checkpoint per epoch, named after the epoch number.
    torch.save(model.state_dict(), f"../model/{epoch}.pth")

[34m[1mwandb[0m: Currently logged in as: [33matharvbhat[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.13.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Running Epoch 1


  8%|██████▎                                                                          | 36/466 [00:38<05:25,  1.32it/s]