In [2]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from tqdm.notebook import tqdm
import os
import torchmetrics
from PIL import Image
import sys
sys.path.append('../../src/trainer')
from trainer import Trainer
sys.path.append('../../src/models')
from resnet18 import ResNet18
sys.path.append('../../src/transforms')
from collate_fn import collate_fn
from label_mapper import LabelMapper
from resizer import ImageResizer


In [3]:
# Run configuration: every value a reader might tune lives here.
BATCH_SIZE = 3          # images per DataLoader batch
NUM_CLASSES = 3         # coarse labels after grouping: benign / atypical / malignant
LEARNING_RATE = 0.001   # Adam step size
EPOCHS = 1              # passes over the dataloader
SEED = 42               # makes torch-driven randomness (weight init, shuffling) repeatable
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's generators (CPU and all CUDA devices) before the model is built and the
# DataLoader shuffles, so Restart & Run All reproduces the same run. Randomness from other
# libraries or non-deterministic CUDA kernels is not covered by this.
torch.manual_seed(SEED)

In [10]:
# Preprocessing: the project's ImageResizer (target size is defined inside that class —
# TODO confirm), then ToTensor to turn the image into a float tensor.
trf = transforms.Compose([
    ImageResizer(),
    transforms.ToTensor(),
])

# ImageFolder numbers the class sub-folders 0..6 in sorted-name order; collapse them into
# the NUM_CLASSES coarse labels. Assumes the folders sort as 3 benign, 2 atypical and
# 2 malignant sub-types — check `dataset.classes` if the folder names ever change.
LABEL_GROUPS = {
    0: 0,  # 0 is the label for benign (BY)
    1: 0,
    2: 0,
    3: 1,  # 1 is the label for atypical (AT)
    4: 1,
    5: 2,  # 2 is the label for malignant (MT)
    6: 2,
}
tar = LabelMapper(LABEL_GROUPS)

# Dataset root: set the BRACS_ROI_DIR environment variable to point elsewhere instead of
# editing the notebook; the fallback is the original author's local Windows path.
DATA_DIR = os.environ.get("BRACS_ROI_DIR", r"D:\AIDS\S2\Project\Dataset\BRACS_RoI\latest_version")
# NOTE(review): this notebook trains on the "val" split and builds no validation loader;
# switch to "train" for a real training run.
SPLIT = "val"
dataset = datasets.ImageFolder(root=os.path.join(DATA_DIR, SPLIT), transform=trf, target_transform=tar)

# Fail fast if the folder layout does not match LABEL_GROUPS: a missing or extra class
# folder would change which index each sub-type receives.
assert len(dataset.classes) == len(LABEL_GROUPS), (
    f"expected {len(LABEL_GROUPS)} class folders under {dataset.root}, found {dataset.classes}"
)

dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [11]:
# Instantiate the project's ResNet18 (n_classes presumably sets the output-layer width)
# and move its parameters to the selected device.
model = ResNet18(n_classes=NUM_CLASSES)
model = model.to(DEVICE)

In [12]:
# Loss, optimizer and metric for the NUM_CLASSES-way classification task.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES)
accuracy = accuracy.to(DEVICE)  # metric state must live on the same device as predictions

# Metrics are passed by name; the names become keys in trainer.history.
tracked_metrics = dict(accuracy=accuracy)
trainer = Trainer(optimizer=optimizer, loss=criterion, device=DEVICE, metrics=tracked_metrics)

# Only the (shuffled) training loader is passed, so no validation metrics are recorded.
trainer.train(model, dataloader, epochs=EPOCHS)

  0%|          | 0/104 [00:00<?, ?it/s]

In [None]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator (e.g. between runs);
# memory still referenced by live tensors such as `model` is not freed by this call.
torch.cuda.empty_cache()

In [13]:
# Metric history recorded by the Trainer (presumably one entry per epoch). The 'val' lists
# stay empty, presumably because trainer.train() was given only the training loader.
trainer.history

{'train': {'accuracy': [0.4226706902890538], 'loss': [1.1665935573669581]},
 'val': {'accuracy': [], 'loss': []}}