In [1]:
import os

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from network import Network
from network import Trainer
from preprocessing import Preprocessing
from progressbar import ProgressBar

KeyboardInterrupt: 

## Čišćenje dataset-ova

In [None]:
# Helper object used by the cells below to check the image folders, read the
# datasets, attach indices, build DataLoaders and display sample grids.
preprocessing = Preprocessing()

In [None]:
# Run the image check on every split/class folder, in the order
# training -> validation -> testing, benign before malignant.
# Presumably this removes or reports unreadable images (see the section header);
# the exact behaviour lives in Preprocessing.CheckImages.
for split in ("training", "validation", "testing"):
    for category in ("benign", "malignant"):
        preprocessing.CheckImages(f"{split}/{category}/")

## Učitavanje podataka u vidu tenzora

In [None]:
# Read one dataset per split, using the same split names that prefix the
# folders checked in the cleaning step.
trainingData, validationData, testingData = (
    preprocessing.DataReader(split) for split in ("training", "validation", "testing")
)

In [None]:
# Build the index collection for each dataset (training, validation, testing, in that order).
trainingIndicies, validationIndicies, testingIndicies = map(
    preprocessing.AppendIndicies,
    (trainingData, validationData, testingData),
)

In [None]:
batchSize = 64

# One DataLoader per split, all sharing the same batch size.
trainingDataLoader, validationDataLoader, testingDataLoader = (
    preprocessing.DataLoader(indices, dataset, batchSize = batchSize)
    for indices, dataset in (
        (trainingIndicies, trainingData),
        (validationIndicies, validationData),
        (testingIndicies, testingData),
    )
)

## Prikaz slučajno izabranih snimaka iz sva tri dataseta
Prva vizuelizacija prikazuje dataset za trening, zatim slijedi validacioni dataset i na kraju dataset za testiranje.

In [None]:
# Display a grid of images from the training DataLoader (presumably randomly selected, per the section header).
preprocessing.ShowGrid(trainingDataLoader)

In [None]:
# Display a grid of images from the validation DataLoader.
preprocessing.ShowGrid(validationDataLoader)

In [None]:
# Display a grid of images from the testing DataLoader.
preprocessing.ShowGrid(testingDataLoader)

In [None]:
# Seed torch's default generator so weight initialisation (and any later draws
# from that generator) is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

network = Network()
trainer = Trainer()
# The optimizer and LR scheduler are created in the training cell below.
# To resume from a checkpoint written there, load it with torch.load(...) and restore
# 'model_state_dict', 'optimizer_state_dict' and 'scheduler' via load_state_dict,
# then start the loop at checkpoint['epoch'] + 1.



In [None]:
# Train the network for `epochs` epochs. After each epoch the model is validated,
# the LR scheduler is stepped on the validation loss, a checkpoint is written and
# the epoch's training record is appended to ./data/training.txt.
# Per-epoch metrics are collected in the lists used by the plotting cells below.
epoch = 0            # first epoch to run; set to checkpoint['epoch'] + 1 to resume
epochs = 30
optimizer = optim.Adam(network.parameters(), lr = 0.001)
# Lower the LR when the validation loss has not improved for `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 2)

checkpointDir = "../checkpoints"
metricsPath = "./data/training.txt"
# torch.save and open(..., "a") fail if the target directories do not exist yet.
os.makedirs(checkpointDir, exist_ok = True)
os.makedirs(os.path.dirname(metricsPath), exist_ok = True)

trainingArray = []                        # [epoch, correct training predictions, summed training loss]
learningRate = []                         # learning rate in effect during each epoch
avgEpochAccuracy, avgEpochLoss = [], []   # training accuracy / mean batch loss per epoch
valEpochAccuracy, valEpochLoss = [], []   # validation accuracy / loss per epoch

for epoch in range(epoch, epochs):

    PATH = f"{checkpointDir}/model{epoch}.pth"
    trainingLoss, trainingAccuracy = 0, 0
    learningRate.append(optimizer.param_groups[0]["lr"])
    network.train()

    loop = tqdm(enumerate(trainingDataLoader), total = len(trainingDataLoader), leave = False)
    for batchIndex, (images, labels) in loop:

        predictions = network(images)
        loss = F.cross_entropy(predictions, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Trainer.Correct is treated as the number of correct predictions in the batch
        # (later cells subtract the accumulated total from the dataset size).
        batchCorrect = trainer.Correct(predictions, labels)
        trainingLoss += loss.item()
        trainingAccuracy += batchCorrect

        loop.set_description(f"Epoch [{epoch}/{epochs}]")
        # Report the real batch accuracy (previously a random number was displayed).
        loop.set_postfix(loss = loss.item(), acc = float(batchCorrect) / len(labels))

    meanTrainingAccuracy = trainingAccuracy / len(trainingData)
    trainingArray.append([epoch, trainingAccuracy, trainingLoss])
    avgEpochAccuracy.append(float(meanTrainingAccuracy))
    avgEpochLoss.append(trainingLoss / len(trainingDataLoader))
    print(f"Epoch {epoch}\t Total Training Loss: {trainingLoss}\t Total Training Accuracy: {trainingAccuracy}\t Mean Training: {meanTrainingAccuracy}")

    # Validation
    valLoss, valCorrect = trainer.Validation(network, validationDataLoader, epoch, epochs)
    validationAccuracy = valCorrect   # correct validation predictions in this epoch
    valEpochAccuracy.append(float(valCorrect) / len(validationData))
    # NOTE(review): confirm Trainer.Validation returns a per-batch mean loss; if it is a
    # sum, normalise it like avgEpochLoss so the two loss curves are comparable.
    valEpochLoss.append(float(valLoss))
    scheduler.step(valLoss)  # lowers the LR when the validation loss plateaus

    torch.save({
        'epoch': epoch,
        'model_state_dict': network.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'scheduler': scheduler.state_dict()
    }, PATH)

    # Append only this epoch's record as one tab-separated line. Writing the whole
    # history every epoch duplicated lines, and file.write() does not accept lists.
    with open(metricsPath, "a") as train:
        train.write("\t".join(str(value) for value in trainingArray[-1]) + "\n")

In [None]:
# PATH = "./model/model.pth"
# torch.save(network.state_dict(), PATH)

In [None]:
# Training vs. validation accuracy per epoch, from the lists filled by the training cell.
# The style must be set before the axes are created: sns.set_style only changes
# rcParams, which matplotlib reads when new axes are built.
sns.set_style("white")
fig, ax = plt.subplots(figsize = (8, 6))
ax.plot(avgEpochAccuracy, label = "Training Accuracy", color = "#001024")
ax.plot(valEpochAccuracy, label = "Validation Accuracy", color = "#FF800B")
ax.set_title("Training and Validation Accuracy", fontsize = 14, fontweight = "bold")
ax.set_xlabel("Epoch", fontsize = 12, fontweight = "bold")
ax.set_ylabel("Value", fontsize = 12, fontweight = "bold")
# Integer epoch ticks that scale with the number of epochs (was hard-coded to 0..5).
ax.xaxis.set_major_locator(plt.MaxNLocator(integer = True))
ax.legend(loc = "lower right")
sns.despine(ax = ax, top = True, right = True)

plt.show()

In [None]:
# Training vs. validation loss per epoch, from the lists filled by the training cell.
# The style must be set before the axes are created: sns.set_style only changes
# rcParams, which matplotlib reads when new axes are built.
sns.set_style("white")
fig, ax = plt.subplots(figsize = (8, 6))
ax.plot(avgEpochLoss, label = "Training Loss", color = "#001024")
ax.plot(valEpochLoss, label = "Validation Loss", color = "#FF800B")
ax.set_title("Training and Validation Loss", fontsize = 14, fontweight = "bold")
ax.set_xlabel("Epoch", fontsize = 12, fontweight = "bold")
ax.set_ylabel("Value", fontsize = 12, fontweight = "bold")
# Integer epoch ticks that scale with the number of epochs (was hard-coded to 0..5).
ax.xaxis.set_major_locator(plt.MaxNLocator(integer = True))
sns.despine(ax = ax, top = True, right = True)

ax.legend()
plt.show()

In [None]:
# Correct/incorrect split of training predictions for the bar chart below.
# `trainingAccuracy` is reset at the start of every epoch and accumulated from
# trainer.Correct, so this reflects the last training epoch only, counted while the
# weights were still being updated (presumably a count of correct predictions).
trainCorrect = trainingAccuracy
trainIncorrect = len(trainingData.targets) - trainingAccuracy

In [None]:
# Correct/incorrect split of validation predictions for the bar chart below.
# Expects `validationAccuracy` (number of correct validation predictions) to have
# been set by the training cell.
valCorrect = validationAccuracy
valIncorrect = len(validationData.targets) - validationAccuracy

In [None]:
# Stacked bars per phase: correct predictions at the bottom, incorrect stacked on top.
# Named `phases` (not `labels`) so the training loop's batch `labels` is not clobbered.
phases = ['Training', 'Validation']
# float() accepts both Python numbers and 0-d tensors, whichever the counts turn out to be.
correctCounts = [float(trainCorrect), float(valCorrect)]
incorrectCounts = [float(trainIncorrect), float(valIncorrect)]

# Set before creating the axes so the style applies to them.
sns.set_style("darkgrid")
fig, ax = plt.subplots(figsize = (6, 5))

ax.bar(
    phases,
    correctCounts,
    color = "#001024",
    width = 0.12,
    align = "center",
    label = "Correct"
)
ax.bar(
    phases,
    incorrectCounts,
    bottom = correctCounts,
    color = "#FF800B",
    width = 0.12,
    align = "center",
    label = "Incorrect"
)

ax.set_title("Number of Correct and Incorrect Predictions\n in the Training and Validation Phase", fontsize = 14, fontweight = "bold")
ax.set_xlabel("Phase", fontsize = 12, fontweight = "bold")
ax.set_ylabel("Number of Values", fontsize = 12, fontweight = "bold")

ax.legend()
plt.show()
