In [1]:
import torch
import torch.nn as nn
from dataset import MyDataset
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, cohen_kappa_score
import numpy as np
import os
import shutil
from tqdm.notebook import tqdm

In [2]:
def _train_one_epoch(model, dataloader, criterion, optimizer, device, writer, epoch, num_epochs):
    """Run one optimisation pass over ``dataloader``.

    Logs the running mean training loss to TensorBoard under "Train/Loss" at
    global step ``epoch * len(dataloader) + step``.

    Returns:
        (avg_loss, labels, predictions): mean per-batch loss over the epoch, and
        the flattened ground-truth and argmax-predicted class lists.
    """
    model.train()
    num_iters_per_epoch = len(dataloader)
    loss_sum = 0.0
    avg_loss = 0.0
    all_labels = []
    all_predictions = []
    progress_bar = tqdm(dataloader, colour="green")
    for step, (images, labels) in enumerate(progress_bar):
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        all_labels.extend(labels.tolist())
        all_predictions.extend(torch.argmax(logits, dim=1).tolist())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running mean kept in O(1) instead of re-averaging the full loss list every step.
        loss_sum += loss.item()
        avg_loss = loss_sum / (step + 1)
        progress_bar.set_description(" Epoch: {}/{}. Loss: {:.5f}".format(epoch + 1, num_epochs, avg_loss))
        writer.add_scalar("Train/Loss", avg_loss, epoch * num_iters_per_epoch + step)
    return avg_loss, all_labels, all_predictions


def _evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on every batch of ``dataloader`` with gradients disabled.

    Returns:
        (mean_loss, labels, predictions): the loss averaged per sample (batch
        losses re-weighted by batch size, assuming the criterion uses mean
        reduction as nn.CrossEntropyLoss does by default), plus the flattened
        ground-truth and argmax-predicted class lists. ``mean_loss`` is NaN
        for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in tqdm(dataloader, colour="yellow"):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            batch_size = labels.size(0)
            # Weight by batch size so a smaller final batch is not over-counted.
            loss_sum += criterion(logits, labels).item() * batch_size
            num_samples += batch_size
            all_labels.extend(labels.tolist())
            all_predictions.extend(torch.argmax(logits, dim=1).tolist())
    mean_loss = loss_sum / num_samples if num_samples else float("nan")
    return mean_loss, all_labels, all_predictions


def train(num_epochs=400, batch_size=8, num_classes=5,
          checkpoint_dir="trained_models/mobilenet_v3", log_dir="tensorboard/mobilenet_v3"):
    """Fine-tune a pretrained torchvision MobileNetV3-Large classifier.

    Trains on ``train-1.csv`` and validates on ``train-2.csv``; both are read
    through MyDataset from ``data/train_images``. Presumably these are two
    splits of the same image folder; confirm against MyDataset. Resumes from
    ``<checkpoint_dir>/last.pt`` when it exists. After each epoch it writes
    ``last.pt``, and also ``best.pt`` when validation accuracy improves.

    Args:
        num_epochs: total number of epochs (including any already completed
            before a resume).
        batch_size: batch size for both the train and validation loaders.
        num_classes: size of the new classification head.
        checkpoint_dir: directory for ``last.pt`` / ``best.pt`` (created if missing).
        log_dir: TensorBoard log directory. It is wiped only on a fresh
            (non-resumed) run.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # NOTE(review): train and eval preprocessing are currently identical. The
    # recorded run reaches ~0.99 train accuracy vs ~0.79 validation accuracy,
    # so adding augmentation (e.g. RandomAffine / ColorJitter) is worth trying.
    train_transform = Compose([ToTensor(), Resize((224, 224)), normalize])
    test_transform = Compose([ToTensor(), Resize((224, 224)), normalize])

    train_dataset = MyDataset(data_path="data/train_images", csv_file="train-1.csv", is_train=True, transform=train_transform)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4)
    val_dataset = MyDataset(data_path="data/train_images", csv_file="train-2.csv", is_train=False, transform=test_transform)
    # Validation must cover every sample in a stable order: no shuffle, keep the last partial batch.
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, drop_last=False, shuffle=False, num_workers=4)

    model = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)
    # Swap the final Linear layer for a num_classes-way head, reusing its input width instead of hard-coding 1280.
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, betas=(0.9, 0.999))

    os.makedirs(checkpoint_dir, exist_ok=True)
    last_path = os.path.join(checkpoint_dir, "last.pt")
    best_path = os.path.join(checkpoint_dir, "best.pt")
    if os.path.exists(last_path):
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        checkpoint = torch.load(last_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint["epoch"]
        best_acc = checkpoint["best_acc"]
    else:
        start_epoch = 0
        best_acc = -1

    # Only discard old TensorBoard logs on a fresh run; when resuming, keep the history.
    if start_epoch == 0 and os.path.isdir(log_dir):
        shutil.rmtree(log_dir)
    os.makedirs(log_dir, exist_ok=True)

    print("Iterations per epoch: {}".format(len(train_dataloader)))
    writer = SummaryWriter(log_dir)
    try:
        for epoch in range(start_epoch, num_epochs):
            # TRAINING
            avg_loss, train_labels, train_predictions = _train_one_epoch(
                model, train_dataloader, criterion, optimizer, device, writer, epoch, num_epochs)
            acc_train = accuracy_score(train_labels, train_predictions)
            cohen_train = cohen_kappa_score(train_labels, train_predictions)
            print("Avg_loss_train: {:.5f}, Acc_train: {:.5f}, Cohen_train: {:.5f}".format(avg_loss, acc_train, cohen_train))

            # VALIDATION (sklearn metrics take (y_true, y_pred))
            val_loss, val_labels, val_predictions = _evaluate(model, val_dataloader, criterion, device)
            acc = accuracy_score(val_labels, val_predictions)
            cohenscore = cohen_kappa_score(val_labels, val_predictions)
            print("Val_loss: {:.5f}, Val_acc: {:.5f}, Val_cohen_kappa_score: {:.5f}".format(val_loss, acc, cohenscore))
            writer.add_scalar("Val/Loss", val_loss, epoch)
            writer.add_scalar("Val/Acc", acc, epoch)

            # Update best_acc BEFORE building the checkpoint so the stored value matches best.pt.
            # Otherwise a resumed run would start from a stale best_acc and could overwrite
            # best.pt with a worse model.
            is_best = acc > best_acc
            if is_best:
                best_acc = acc
            checkpoint = {
                "epoch": epoch + 1,
                "best_acc": float(best_acc),
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(checkpoint, last_path)
            if is_best:
                torch.save(checkpoint, best_path)
    finally:
        # Flush pending TensorBoard events even when training is interrupted (e.g. KeyboardInterrupt).
        writer.close()

# Entry point: executing this cell in the kernel (or the file as a script) launches training.
if __name__ == "__main__":
    train()


  checkpoint = torch.load("trained_models/mobilenet_v3/last.pt")


364


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02771, Acc_train: 0.98626, Cohen_train: 0.97928


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.66482, Val_acc: 0.78846, Val_cohen_kappa_score: 0.67670


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02442, Acc_train: 0.98764, Cohen_train: 0.98135


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.69092, Val_acc: 0.78709, Val_cohen_kappa_score: 0.67274


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02197, Acc_train: 0.98695, Cohen_train: 0.98030


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.76289, Val_acc: 0.78846, Val_cohen_kappa_score: 0.67509


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02588, Acc_train: 0.98729, Cohen_train: 0.98085


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.68137, Val_acc: 0.79121, Val_cohen_kappa_score: 0.68102


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02518, Acc_train: 0.98901, Cohen_train: 0.98343


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.77597, Val_acc: 0.77198, Val_cohen_kappa_score: 0.64784


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.06289, Acc_train: 0.97665, Cohen_train: 0.96473


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.75667, Val_acc: 0.78297, Val_cohen_kappa_score: 0.67347


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.07235, Acc_train: 0.97562, Cohen_train: 0.96322


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.43806, Val_acc: 0.78709, Val_cohen_kappa_score: 0.67846


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.03141, Acc_train: 0.98523, Cohen_train: 0.97774


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.72209, Val_acc: 0.78022, Val_cohen_kappa_score: 0.66615


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02285, Acc_train: 0.98935, Cohen_train: 0.98395


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.54009, Val_acc: 0.78984, Val_cohen_kappa_score: 0.68245


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02044, Acc_train: 0.98832, Cohen_train: 0.98240


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.70433, Val_acc: 0.79670, Val_cohen_kappa_score: 0.69170


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02060, Acc_train: 0.98764, Cohen_train: 0.98135


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.59493, Val_acc: 0.80082, Val_cohen_kappa_score: 0.69657


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.01899, Acc_train: 0.98867, Cohen_train: 0.98293


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.74102, Val_acc: 0.79670, Val_cohen_kappa_score: 0.68825


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02354, Acc_train: 0.99038, Cohen_train: 0.98550


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.60097, Val_acc: 0.80220, Val_cohen_kappa_score: 0.69784


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.01971, Acc_train: 0.98970, Cohen_train: 0.98444


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.52610, Val_acc: 0.79945, Val_cohen_kappa_score: 0.69435


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.03787, Acc_train: 0.98214, Cohen_train: 0.97302


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.72818, Val_acc: 0.73764, Val_cohen_kappa_score: 0.60828


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.07810, Acc_train: 0.97562, Cohen_train: 0.96323


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.75683, Val_acc: 0.78846, Val_cohen_kappa_score: 0.67555


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.04507, Acc_train: 0.98043, Cohen_train: 0.97049


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.41794, Val_acc: 0.79121, Val_cohen_kappa_score: 0.68133


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.03294, Acc_train: 0.98729, Cohen_train: 0.98085


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.47838, Val_acc: 0.79670, Val_cohen_kappa_score: 0.68676


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02237, Acc_train: 0.98935, Cohen_train: 0.98395


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.62315, Val_acc: 0.79258, Val_cohen_kappa_score: 0.68515


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02520, Acc_train: 0.98798, Cohen_train: 0.98188


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.46282, Val_acc: 0.78571, Val_cohen_kappa_score: 0.67099


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.01864, Acc_train: 0.99004, Cohen_train: 0.98498


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.54972, Val_acc: 0.78846, Val_cohen_kappa_score: 0.67581


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.01964, Acc_train: 0.99073, Cohen_train: 0.98602


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.59972, Val_acc: 0.77747, Val_cohen_kappa_score: 0.66396


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.04780, Acc_train: 0.98317, Cohen_train: 0.97465


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.59016, Val_acc: 0.77335, Val_cohen_kappa_score: 0.64756


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.04367, Acc_train: 0.98249, Cohen_train: 0.97358


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.73160, Val_acc: 0.77747, Val_cohen_kappa_score: 0.65643


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02740, Acc_train: 0.98729, Cohen_train: 0.98083


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.73015, Val_acc: 0.76648, Val_cohen_kappa_score: 0.64736


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.06355, Acc_train: 0.97699, Cohen_train: 0.96530


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.59310, Val_acc: 0.78159, Val_cohen_kappa_score: 0.67208


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.03659, Acc_train: 0.98386, Cohen_train: 0.97565


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.52154, Val_acc: 0.79945, Val_cohen_kappa_score: 0.69503


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02650, Acc_train: 0.98523, Cohen_train: 0.97773


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.68417, Val_acc: 0.78571, Val_cohen_kappa_score: 0.67839


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02056, Acc_train: 0.98867, Cohen_train: 0.98292


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.58797, Val_acc: 0.80220, Val_cohen_kappa_score: 0.69917


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.02169, Acc_train: 0.98970, Cohen_train: 0.98446


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.55620, Val_acc: 0.79808, Val_cohen_kappa_score: 0.69236


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.05155, Acc_train: 0.98283, Cohen_train: 0.97408


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.82835, Val_acc: 0.76786, Val_cohen_kappa_score: 0.65002


  0%|          | 0/364 [00:00<?, ?it/s]

Avg_loss_train: 0.04760, Acc_train: 0.98043, Cohen_train: 0.97049


  0%|          | 0/91 [00:00<?, ?it/s]

Val_loss: 1.95362, Val_acc: 0.78709, Val_cohen_kappa_score: 0.67568


  0%|          | 0/364 [00:00<?, ?it/s]


KeyboardInterrupt

