In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)
!pip install wandb -qU
import sys
sys.path.append("/content/drive/MyDrive/ImportScripts")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import wandb
from datetime import datetime
import torch.optim as optim
from sklearn.model_selection import KFold

import derm7pt_data
from derm7pt_data import Derm7pt_data
from Model import Simple_CNN_Net, Simple_CNN_PerfectConcepts, Concept_To_Label_Net

from importlib import reload

In [None]:
# SECURITY: do not hardcode the W&B API key in the notebook. Without an explicit
# key, wandb.login() uses the WANDB_API_KEY environment variable / stored
# credentials or asks for the key interactively.
# NOTE(review): the key that used to be written in this cell has been exposed
# and should be revoked in the W&B account settings.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Re-execute derm7pt_data so edits to the module are picked up without
# restarting the kernel.
reload(derm7pt_data)
# NOTE: reloading Model the same way would first require `import Model`,
# since only names from it are imported in this notebook.

#Data loading
random_state = 42
torch.manual_seed(random_state)
path = os.path.normpath("/content/drive/MyDrive/Derm7pt/")  #local path: os.path.normpath('Data//Derm7pt')

# Access the class through the reloaded module: the name bound by
# `from derm7pt_data import Derm7pt_data` still refers to the class object
# from before the reload, so using it would silently ignore the reload.
derm7pt = derm7pt_data.Derm7pt_data(path)
metadata = derm7pt.metadata
print("Data shape:", metadata.shape)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("device:", device)

In [None]:
# Display the column labels of the loaded metadata.
metadata.columns

Index(['case_num', 'diagnosis', 'seven_point_score', 'pigment_network',
       'streaks', 'pigmentation', 'regression_structures', 'dots_and_globules',
       'blue_whitish_veil', 'vascular_structures',
       'level_of_diagnostic_difficulty', 'elevation', 'location', 'sex',
       'clinic', 'derm', 'diagnosis_num', 'is_cancer', 'abbrevs', 'info',
       'pigment_network_num', 'pigment_network_score', 'streaks_num',
       'streaks_score', 'pigmentation_num', 'pigmentation_score',
       'regression_structures_num', 'regression_structures_score',
       'dots_and_globules_num', 'dots_and_globules_score',
       'blue_whitish_veil_num', 'blue_whitish_veil_score',
       'vascular_structures_num', 'vascular_structures_score'],
      dtype='object')

In [None]:
#Helper functions to calculate the majority class baseline
def majority_class_baseline(val_idx, mode_txt=""):
    """Compute trivial baselines for the samples of ``derm7pt`` selected by ``val_idx``.

    Two baselines are computed and printed:
    * label baseline: accuracy of always predicting the most frequent label
      (see ``majority_class_accuracy_by_labels``),
    * concept baseline: accuracy of predicting 0 for every concept.

    Args:
        val_idx: sequence of dataset indices to evaluate on.
        mode_txt: text used in the printed messages (e.g. "validation").

    Returns:
        Tuple ``(baseline_accuracy, concept_baseline_accuracy)``.

    Raises:
        ValueError: if ``val_idx`` is empty (previously ``None`` was returned
            silently, which only failed later when the caller unpacked it).
    """
    if len(val_idx) == 0:
        raise ValueError("majority_class_baseline: val_idx is empty, no baseline can be computed")

    print("start ", mode_txt, " baseline: ", datetime.now())
    # A single batch sized to the selection, so every selected sample is part
    # of the baseline regardless of how large the split is.
    # NOTE(review): the loader also yields the inputs although only the labels
    # are used here.
    majority_loader = DataLoader(
        dataset=derm7pt,
        batch_size=len(val_idx),
        sampler=torch.utils.data.SubsetRandomSampler(val_idx),
    )
    _, labels, concept_labels = next(iter(majority_loader))
    baseline, baseline_accuracy = majority_class_accuracy_by_labels(labels)

    #concept baseline: predict 0 for every concept
    concept_baseline = 0
    concept_outputs = torch.zeros(len(labels), num_concepts)
    concept_baseline_accuracy = ((concept_outputs == concept_labels).sum().item()) / (len(labels)*num_concepts)

    print("end ", mode_txt, " baseline:   ", datetime.now(), ", baseline: ", baseline, " percent ",  baseline_accuracy, " concept_baseline: ", concept_baseline, " concept_", mode_txt, "_baseline: ", concept_baseline_accuracy)
    return baseline_accuracy, concept_baseline_accuracy

def majority_class_accuracy_by_labels(true_labels):
    """Return the most frequent label and the accuracy of always predicting it.

    Args:
        true_labels: tensor of labels.

    Returns:
        Tuple ``(majority_class, accuracy)``, where ``accuracy`` is the count
        of the most frequent label divided by ``len(true_labels)``.
    """
    classes, frequencies = true_labels.unique(return_counts=True)
    # position of the most frequent label among the unique values
    top = frequencies.argmax()
    return classes[top], frequencies[top] / len(true_labels)

In [None]:
#Training configuration
# ---- hyperparameters ----
n_epochs = 50
learning_rate = 0.0002
n_folds = 8
batch_size = 8
learn_concepts = True   # whether a loss is also calculated for the concepts

# ---- sizes derived from the dataset ----
num_concepts = len(derm7pt.concepts_mapping)
num_classes = derm7pt.diagnosis_mapping[derm7pt.model_columns["label"]].nunique()

# ---- loss functions ----
criterion = nn.CrossEntropyLoss()   # categorical cross-entropy for the class outputs
criterion_concept = nn.BCELoss()    # binary cross-entropy for the concept outputs

# wandb.init is kept in its own cell, separate from the main training loop,
# because print statements were displayed incorrectly during training otherwise.
wandb.init(
    project="PracticalWork",        # wandb project this run is logged to
    name=f"run{datetime.now()}",
    # hyperparameters and run metadata to track
    config=dict(
        learning_rate=learning_rate,
        architecture="Simple_CNN_Net",
        dataset="derm7pt",
        labels=derm7pt.model_columns["label"],
        epochs=n_epochs,
        batch_size=batch_size,
        n_folds=n_folds,
        device=device,
        num_classes=num_classes,
        num_concepts=num_concepts,
        learn_concepts=learn_concepts,
        random_state=random_state,
    ),
)

[34m[1mwandb[0m: Currently logged in as: [33mtraglert[0m ([33mnlp_ass3[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Training loop
kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)

# Exit code reported to wandb. It is only set to 0 once training ran through,
# so an interrupted or crashed cell still closes the run (instead of leaving
# it open) and does not mark it as successful.
exit_code = 1
try:
    for fold, (train_idx, val_idx) in enumerate(kf.split(derm7pt.metadata)):
        # baselines on the validation split (the helper prints them)
        simple_val_baseline, concept_val_baseline = majority_class_baseline(val_idx, "validation")
        #left out for performance reasons simple_train_baseline, concept_train_baseline = majority_class_baseline(train_idx, "train")

        train_loader = DataLoader(
            dataset=derm7pt,
            batch_size=batch_size,
            sampler=torch.utils.data.SubsetRandomSampler(train_idx),
        )
        val_loader = DataLoader(
            dataset=derm7pt,
            batch_size=batch_size,
            sampler=torch.utils.data.SubsetRandomSampler(val_idx),
        )
        # Number of training batches per epoch, used for averaging the losses
        # and for the progress print (max(..., 1) avoids a division by zero).
        num_train_batches = max(len(train_loader), 1)

        # Instantiate the model
        model = Simple_CNN_Net(num_classes=num_classes,num_concepts=num_concepts, image_size=derm7pt.image_size)
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        for epoch in range(n_epochs):
            running_loss = 0.0
            running_loss_concepts = 0.0
            train_total_correct = 0
            train_concepts_total_correct = 0
            model.train()
            for inputs, labels, concept_labels in train_loader:
                # Class indices as a flat long tensor. reshape(-1) (instead of
                # squeeze()) keeps the batch dimension for a final batch of size 1;
                # CrossEntropyLoss on class indices equals the loss on one-hot targets.
                targets = labels.reshape(-1).long()
                inputs, targets, concept_labels = inputs.to(device), targets.to(device), concept_labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()
                # forward pass for both concepts and outputs
                concept_outputs, outputs = model(inputs)
                loss_outputs = criterion(outputs, targets)
                loss = loss_outputs
                if learn_concepts:
                    loss_concepts = criterion_concept(concept_outputs, concept_labels)
                    loss = loss + loss_concepts
                    # statistics: average loss
                    running_loss_concepts += loss_concepts.item()

                    # concept accuracy
                    train_concepts_total_correct += (concept_outputs.round() == concept_labels).sum().item()
                # One backward pass over the summed loss accumulates the same
                # gradients as two separate passes, without retain_graph.
                loss.backward()
                optimizer.step()
                # statistics: average loss
                running_loss += loss_outputs.item()

                # train accuracy
                predicted = outputs.argmax(dim=1)
                train_total_correct += (predicted == targets).sum().item()

            running_loss /= num_train_batches
            running_loss_concepts /= num_train_batches
            train_accuracy = train_total_correct / len(train_idx)
            concept_train_accuracy = train_concepts_total_correct / (len(train_idx)*num_concepts)

            # Validation
            model.eval()
            correct = 0
            concept_correct = 0
            total = 0
            with torch.no_grad():
                for inputs, labels, concept_labels in val_loader:
                    inputs, concept_labels = inputs.to(device), concept_labels.to(device)
                    # flat labels, consistent with training; avoids an accidental
                    # broadcast in the comparison if labels carry an extra dimension
                    targets = labels.reshape(-1).to(device)
                    concept_outputs, outputs = model(inputs)
                    predicted = outputs.argmax(dim=1)
                    total += targets.size(0)
                    correct += (predicted == targets).sum().item()
                    concept_correct += (concept_outputs.round() == concept_labels).sum().item()
            val_accuracy = correct/total
            concept_val_accuracy = concept_correct/(total*num_concepts)
            # "concept_loss" was previously logged as "concept_loss:" (stray colon in the key)
            wandb.log({"loss": running_loss, "train_accuracy": train_accuracy, "concept_loss": running_loss_concepts, "concept_train_accuracy": concept_train_accuracy, "validation_accuracy": val_accuracy, "concept_validation_accuracy": concept_val_accuracy})
            # second number: training batches per epoch (belongs to the printed training loss)
            print('[%d, %3d] loss: %.4f, val_accuracy: %.4f, time: %s' % (epoch + 1, num_train_batches, running_loss, val_accuracy, datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        #Only one fold for performance reasons
        break
    exit_code = 0
finally:
    wandb.finish(exit_code=exit_code)

print('Finished Training')

start  validation  baseline:  2024-11-06 16:49:48.506588
end  validation  baseline:    2024-11-06 16:51:39.466695 , baseline:  tensor(1)  percent  tensor(0.5984)  concept_baseline:  0  concept_ validation _baseline:  0.7716535433070866
[1,  16] loss: 1.5412, val_accuracy: 0.5984, time: 2024-11-06 17:04:13
[2,  16] loss: 1.5110, val_accuracy: 0.5984, time: 2024-11-06 17:04:35
[3,  16] loss: 1.4903, val_accuracy: 0.5984, time: 2024-11-06 17:04:55


KeyboardInterrupt: 