In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
import torchvision.transforms.v2 as transforms
import torch.nn as nn
from kvasir_capsule_dataset import get_dataloader
from trainer import train_kc_model
import gc
from custom_transforms import GaussianBlur, RandomChoiceExtended
import os

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
with_gpu = torch.cuda.is_available()
device = torch.device("cuda" if with_gpu else "cpu")

print('We are now using %s.' % device)

# Root folder of the Kvasir-Capsule labelled-image data.
# Set the KVASIR_CAPSULE_DIR environment variable to run the notebook on another
# machine; the fallback is the previously hard-coded local path, so an existing
# setup keeps resolving to exactly the same files.
DATA_ROOT = os.environ.get(
    "KVASIR_CAPSULE_DIR",
    r"C:\Users\JadHa\Desktop\Uni\DLMB\DLMI-Project\kvasir-capsule-labeled-images",
)

train_path = os.path.join(DATA_ROOT, "dataset_train.csv")   # CSV used for the training loader
val_path = os.path.join(DATA_ROOT, "dataset_test.csv")      # CSV used for the validation loader
dataset_path = os.path.join(DATA_ROOT, "labelled_images")   # passed as dataset_path to both loaders

# ResNet-18 built without pretrained weights (the ImageNet option is left commented out).
resnet = resnet18() # weights=ResNet18_Weights.IMAGENET1K_V1
# Swap the classification head for a 512 -> 2 linear layer (two output classes).
resnet.fc = nn.Linear(512, 2)
# Move the whole network, new head included, to the selected device in one go.
resnet = resnet.to(device)

# Candidate augmentations, each named for readability.
# Random crop to 96 px followed by a resize to 336x336.
crop_then_resize = transforms.Compose([
    transforms.RandomResizedCrop(size=96),
    transforms.Resize(size=[336, 336]),
])
any_angle_rotation = transforms.RandomRotation(degrees=360)
light_blur = GaussianBlur(kernel_size=3)
# Positional arguments are brightness, contrast, saturation, hue.
mild_color_jitter = transforms.ColorJitter(0.1, 0.1, 0.003, 0.003)

transforms_list = [
    crop_then_resize,
    any_angle_rotation,
    light_blur,
    mild_color_jitter,
]

# RandomChoiceExtended is a project-local transform; it is configured here with
# min_transforms=0 and max_transforms=3 over the candidates above.
transform = RandomChoiceExtended(transforms_list, min_transforms=0, max_transforms=3)

# Arguments shared by the training and validation loaders.
loader_kwargs = dict(dataset_path=dataset_path, batch_size=128, shuffle=True)

# Training loader: receives the augmentation pipeline and disables drop_data_till_balanced.
train_loader = get_dataloader(
    csv_path=train_path,
    transforms=transform,
    drop_data_till_balanced=False,
    **loader_kwargs,
)
# Validation loader: no transforms argument is passed, so get_dataloader's defaults apply.
val_loader = get_dataloader(csv_path=val_path, **loader_kwargs)

# AdamW over every model parameter, learning rate 3e-4.
optimizer = optim.AdamW(resnet.parameters(), lr=3e-4)
# Gradient scaler for mixed-precision training; handed to train_kc_model later on.
scaler = torch.cuda.amp.GradScaler()

# Class-weighted cross-entropy: class index 1 counts 20x as much as class index 0.
# The weights are written as floats because CrossEntropyLoss needs a floating-point
# weight tensor; torch.tensor([1, 20]) would be int64 and raise a dtype error as soon
# as the loss is evaluated on floating-point logits.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 20.0]).to(device))



We are now using cuda.


# Trying out Augmentations

In [None]:
import matplotlib.pyplot as plt

# Take the first sample of the training dataset and run it through the augmentation pipeline.
# NOTE(review): train_loader was built with transforms=transform, so this "original" may
# already be augmented if the dataset applies transforms on item access — confirm.
image = train_loader.dataset[0][0]
image_aug = transform(image)

shapes = (image.shape, image_aug.shape)
print("Original shape : %s, Augmented shape : %s"%shapes)

# Concatenate along dim 2 and reorder axes with permute(1,2,0) for imshow
# (presumably channels-first (C, H, W) tensors placed side by side — TODO confirm).
side_by_side = torch.cat([image, image_aug], dim=2)
plt.imshow(side_by_side.permute(1,2,0).numpy())


In [None]:
# Rebuild the model, optimizer and grad scaler from scratch (same configuration as the setup cell).
resnet = resnet18() # weights=ResNet18_Weights.IMAGENET1K_V1
resnet.fc = nn.Linear(512, 2)   # two-class head
resnet = resnet.to(device)      # move backbone and new head together
optimizer = optim.AdamW(resnet.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Restore model, optimizer and grad-scaler state from a checkpoint under saved_models/.
# NOTE(review): the file name mentions "vqvae_vctk" while the model here is a ResNet-18 —
# confirm this is the intended checkpoint.
# map_location=device remaps stored tensors onto the active device, so a checkpoint
# written on a GPU machine can still be loaded when running on CPU.
# torch.load unpickles the file: only load checkpoints from a trusted source.
chkpoint = torch.load(
    os.path.join("saved_models", "vqvae_vctk_amp_clip1_2.pt"),
    map_location=device,
)
resnet.load_state_dict(chkpoint["model_state_dict"])
optimizer.load_state_dict(chkpoint["optimizer"])
scaler.load_state_dict(chkpoint["scaler"])

In [None]:
# Free memory before training. Run the Python garbage collector first so tensors held
# by unreachable objects are handed back to the CUDA caching allocator, then release
# the cached blocks; in the opposite order the memory freed by the collection would
# stay cached.
gc.collect()
torch.cuda.empty_cache()
# Launch training with model_name "Resnet_AMP" for 40 epochs on the selected device.
train_kc_model(resnet, optimizer, criterion, train_loader, val_loader, scaler, model_name="Resnet_AMP", epochs=40, device=device)