In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.utils.data
import seaborn as sns
# import wandb
from tqdm import tqdm
from torchvision.transforms import v2
import torch.nn as nn
from dataclasses import dataclass
#from train import train_with_metrics


In [2]:
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu' if torch.cuda.is_available() else 'cuda')
device = 'cpu'
print(device)
torch.set_default_device(device)

cpu


In [3]:
from classification_dataset import MyDataset

In [4]:
# !pip3 install wandb -qU

In [5]:
# wandb.login()

In [6]:
train_dataset = MyDataset(base_folder="data/", split_name="train")
test_dataset = MyDataset(base_folder="data/", split_name="test")

bs = 10
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, num_workers=0, shuffle=True,
                                              generator=torch.Generator(device=device))
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=4, shuffle=True,
                                             generator=torch.Generator(device=device))

In [7]:
@dataclass
class InputParameters:
    model: nn.Module
    optimizer: torch.optim.Optimizer
    scheduler: torch.optim.lr_scheduler.ExponentialLR
def save_checkpoint(model_input, optimizer_input, scheduler_input, filename):
    print("=> Saving checkpoint")
    input_parameters = InputParameters(model_input, optimizer_input, scheduler_input)
    torch.save(input_parameters, filename)
def load_checkpoint(filename):
    print("=> Loading checkpoint")
    input_parameters = torch.load(filename)
    return input_parameters.model, input_parameters.optimizer, input_parameters.scheduler

In [8]:
#CUDA_VISIBLE_DEVICES=0

In [9]:
sweep_config: dict = {
    "project": "licenta",
    "metric":
        {"name": "loss","goal": "minimize"}
    ,
    "method": "grid", # grid/random
     "parameters":
    #     {
    #     "learning_rate": {
    #         "values": [1e-4, 1e-5, 1e-6]
    #         },
    #     "number_of_epochs": {
    #         "values": [8,9,10]
    #         },
    #     },
    None
}
parameters: dict = {
    "learning_rate": {
        "values": [1e-4, 1e-5, 1e-6]
    },
    "number_of_epochs": {
        "values": [50,75,100]
    },
    "batch_sizes": {
        "values": [10, 12, 15],
    }
}
sweep_config["parameters"] = parameters
#sweep_id = wandb.sweep(sweep_config)

In [10]:
def train_one_epoch(loader, model_input, optimizer_input, loss_function, scheduler_input):
  for data, targets in tqdm(loader):
    data, targets = data.to(device), targets.to(device)
    optimizer_input.zero_grad()
    predictions = model_input(data)
    loss = loss_function(predictions[0], targets)

    print(f"Loss: {loss}")
    print(f"Predicted: {predictions}")
    print(f"Real class: {targets}")

    loss.backward()
    optimizer_input.step()

    #wandb.log({"loss": loss.item()})
  #scheduler_input.step()

In [11]:
from tensorflow import Tensor


def get_accuracy(real_target: Tensor, predicted_target: Tensor):
    # noinspection PyTypeChecker
    correctly_predicted = torch.sum(real_target == predicted_target).item()
    all_targets = real_target.numel()
    return correctly_predicted / all_targets


def get_precision(real_target: Tensor, predicted_target: Tensor):
    # noinspection PyTypeChecker
    all_positive_predicted = torch.sum(predicted_target == 2).item() + torch.sum(predicted_target == 3).item()
    # noinspection PyTypeChecker
    true_positive_targets = torch.sum(
        (real_target == predicted_target) &
        (real_target == 2)).item() + torch.sum(
        (real_target == predicted_target) &
        (real_target == 3)).item()
    precision = true_positive_targets / max(all_positive_predicted, 1)
    return precision


def get_recall(real_target: Tensor, predicted_target: Tensor):
    # noinspection PyTypeChecker
    all_positive_targets = torch.sum(real_target == 2).item() + torch.sum(real_target == 3).item()
    # noinspection PyTypeChecker
    true_positive_targets = torch.sum(
        (real_target == predicted_target) &
        (real_target == 2)).item() + torch.sum(
        (real_target == predicted_target) &
        (real_target == 3)).item()
    recall = true_positive_targets / max(all_positive_targets, 1)
    return recall


def get_confusion_matrix(real_target: Tensor, predicted_target: Tensor):
    confusion_matrix = np.empty((4, 4), dtype=np.uint16)
    actionable_predicted_values, benign_predicted_values, \
        cancer_predicted_values, normal_predicted_values = \
        get_predicted_classes_for_all_classes(predicted_target, real_target)
    for i in range(0, 4):
        set_confusion_matrix_values(actionable_predicted_values, benign_predicted_values, cancer_predicted_values,
                                    confusion_matrix, i, normal_predicted_values)
    return confusion_matrix


def set_confusion_matrix_values(actionable_predicted_values, benign_predicted_values, cancer_predicted_values,
                                confusion_matrix, i, normal_predicted_values):
    # noinspection PyTypeChecker
    confusion_matrix[0][i] = torch.sum(normal_predicted_values == i).item()
    # noinspection PyTypeChecker
    confusion_matrix[1][i] = torch.sum(actionable_predicted_values == i).item()
    # noinspection PyTypeChecker
    confusion_matrix[2][i] = torch.sum(benign_predicted_values == i).item()
    # noinspection PyTypeChecker
    confusion_matrix[3][i] = torch.sum(cancer_predicted_values == i).item()


def get_predicted_classes_for_all_classes(predicted_target, real_target):
    normal_predicted_values = get_predicted_class_for_normal_class(predicted_target, real_target)
    actionable_predicted_values = get_predicted_class_for_actionable_class(predicted_target, real_target)
    benign_predicted_values = get_predicted_class_for_benign_class(predicted_target, real_target)
    cancer_predicted_values = get_predicted_class_for_malign_class(predicted_target, real_target)
    return actionable_predicted_values, benign_predicted_values, cancer_predicted_values, normal_predicted_values


def get_predicted_class_for_malign_class(predicted_target, real_target):
    cancer_real_indexes = (real_target == 3).nonzero()
    cancer_predicted_values = predicted_target[cancer_real_indexes]
    return cancer_predicted_values


def get_predicted_class_for_benign_class(predicted_target, real_target):
    benign_real_indexes = (real_target == 2).nonzero()
    benign_predicted_values = predicted_target[benign_real_indexes]
    return benign_predicted_values


def get_predicted_class_for_actionable_class(predicted_target, real_target):
    actionable_real_indexes = (real_target == 1).nonzero()
    actionable_predicted_values = predicted_target[actionable_real_indexes]
    return actionable_predicted_values


def get_predicted_class_for_normal_class(predicted_target, real_target):
    normal_real_indexes = (real_target == 0).nonzero()
    normal_predicted_values = predicted_target[normal_real_indexes]
    return normal_predicted_values

def check_metrics(loader, model_aux):

    with torch.no_grad():
        all_predicted_classes = torch.tensor([]).to(device)
        all_real_classes = torch.tensor([]).to(device)
        for data, targets in tqdm(loader):
            data, real_class = data.to(device), targets.to(device).argmax(dim=1)
            predicted_class = model_aux(data)[0].to(device).argmax(dim=1)

            all_predicted_classes = torch.cat((all_predicted_classes, predicted_class))
            all_real_classes = torch.cat((all_real_classes, real_class))

            accuracy = get_accuracy(all_real_classes, all_predicted_classes)
            precision = get_precision(all_real_classes, all_predicted_classes)
            recall = get_recall(all_real_classes, all_predicted_classes)
            confusion_matrix = get_confusion_matrix(all_real_classes, all_predicted_classes)
    return accuracy, precision, recall, confusion_matrix

In [12]:
def train_with_metrics(model_to_train, optimizer_to_train, scheduler_to_train):
    #run = wandb.init(config=sweep_config)
    loss_function = nn.CrossEntropyLoss(torch.tensor((1, 1, 2, 2)))
    #learning_rate = wandb.config.learning_rate
    #number_of_epochs = wandb.config.number_of_epochs
    number_of_epochs = 50
    for epoch in range(number_of_epochs):
        model_to_train.train()
        print(f"=> Training epoch: {epoch+1}")
        train_one_epoch(train_data_loader, model_to_train, optimizer_to_train, loss_function, scheduler_to_train)
        print("=> Finished training...")
        model_to_train.eval()
        print("=> Calculating metrics...")
        accuracy, precision, recall, confusion_matrix = check_metrics(test_data_loader, model_to_train)
    #wandb.log({"accuracy": accuracy, "precision": precision})
        print(f"Accuracy: {accuracy}")
        print(f"Precision: {precision}")
        print(f"Recall: {recall}")
        save_checkpoint(model_to_train, optimizer_to_train, scheduler_to_train, f"./checkpoints/acc_{epoch}.pth")
        # plt.figure(figsize=(7, 5))
        # sns.heatmap(confusion_matrix, annot=True, cmap='Blues', fmt='g',
        #         xticklabels=['normal', 'actionable', 'benign', 'malign'],
        #         yticklabels=['normal', 'actionable', 'benign', 'malign'])
        # plt.xlabel('Predicted labels')
        # plt.ylabel('True labels')
        # plt.title('Confusion Matrix')
        # plt.show()
    #run.finish()

In [13]:
from model_aux import GoogLeNet

model = GoogLeNet()

#def get_batch_sizes():
    #run = wandb.init(config=sweep_config)
#bs = wandb.config.batch_sizes

#wandb.agent(sweep_id, function=train_with_metrics, count=5)
    #run.finish()
learning_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
train_with_metrics(model, optimizer_to_train=optimizer, scheduler_to_train=scheduler)

=> Training epoch: 1


  0%|          | 0/57 [00:00<?, ?it/s]

Loss: 1.670336127281189
Predicted: (tensor([[0.2307, 0.2319, 0.2832, 0.2543],
        [0.2556, 0.2252, 0.2748, 0.2445],
        [0.2697, 0.2344, 0.2676, 0.2284],
        [0.2605, 0.2216, 0.2720, 0.2460],
        [0.2606, 0.2176, 0.2809, 0.2409],
        [0.2588, 0.2275, 0.2816, 0.2321],
        [0.2618, 0.2270, 0.2940, 0.2173],
        [0.2544, 0.2429, 0.2825, 0.2203],
        [0.2561, 0.2268, 0.2724, 0.2447],
        [0.2768, 0.2398, 0.2682, 0.2152]], grad_fn=<SoftmaxBackward0>), tensor([[-0.0194, -0.0468, -0.0224, -0.0089],
        [-0.0940,  0.0869, -0.0932,  0.0372],
        [ 0.0583,  0.0672, -0.0948, -0.0077],
        [-0.0184, -0.1071, -0.0219, -0.1266],
        [-0.0139, -0.1322, -0.0863,  0.0297],
        [-0.0166, -0.1935,  0.0553,  0.0755],
        [-0.0138, -0.1023, -0.1150, -0.0160],
        [ 0.1515, -0.1018, -0.0647,  0.0496],
        [-0.0166, -0.0080, -0.0404,  0.1155],
        [ 0.2398, -0.1878, -0.0703,  0.0416]], grad_fn=<AddmmBackward0>), tensor([[-0.0527,  0.0546,

  2%|▏         | 1/57 [04:13<3:56:09, 253.03s/it]

Loss: 1.654123306274414
Predicted: (tensor([[0.2505, 0.2340, 0.2718, 0.2436],
        [0.2474, 0.2294, 0.2743, 0.2489],
        [0.2806, 0.2034, 0.2783, 0.2377],
        [0.2512, 0.2419, 0.2782, 0.2287],
        [0.2673, 0.2166, 0.2830, 0.2331],
        [0.2804, 0.2247, 0.2700, 0.2249],
        [0.2595, 0.2323, 0.2713, 0.2370],
        [0.2809, 0.2210, 0.2839, 0.2143],
        [0.2459, 0.2301, 0.2767, 0.2472],
        [0.2721, 0.2221, 0.2787, 0.2272]], grad_fn=<SoftmaxBackward0>), tensor([[-2.8159e-02, -3.6450e-02, -5.3179e-02,  2.6307e-01],
        [-8.2034e-02, -2.8332e-02, -9.8023e-02,  5.1501e-02],
        [ 1.2571e-01, -1.6694e-01,  2.2056e-01, -5.2026e-02],
        [ 6.2877e-02,  2.0869e-02,  9.5487e-02,  2.1517e-02],
        [-2.2551e-02, -4.1781e-03,  2.0026e-02,  1.9584e-01],
        [ 3.3476e-03, -1.7130e-01,  1.0720e-01,  6.8158e-02],
        [ 7.1804e-02,  1.1190e-02, -1.1092e-01,  1.4620e-01],
        [-7.3048e-03, -1.0374e-01, -1.6774e-02, -3.1905e-02],
        [ 5.1074e-

  4%|▎         | 2/57 [07:39<3:30:47, 229.96s/it]


KeyboardInterrupt: 

In [None]:
#wandb.agent(sweep_id, function=get_batch_sizes, count=3)