In [1]:
import unet
import image_processor
import torch
import os
import zipfile
from datetime import datetime

In [2]:
def modelWriter(model, path : str):
    torch.save(model.state_dict(), path)
    
#thanks chatGPT
def zip_file(file_path, zip_path):
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipf.write(file_path, arcname=os.path.basename(file_path))
    os.remove(file_path)


In [3]:
#imageDir = "Data_Renamed\processed_data\\test_data_downsampled_16x\images_rename"
#labelDir = "Data_Renamed\processed_data\\test_data_downsampled_16x\labels_rename"
imageDir = "Data_Renamed\\raw_data\\train_val\dataset_2\images_rename"
labelDir = "Data_Renamed\\raw_data\\train_val\dataset_2\labels_rename"
outputPath = "Output_Models"

#labelTransform = torchvision.transforms.Compose([
#    torchvision.transforms.Resize((16, 64))])

name = "unet_norm"

batchSizes = [16]
epochs = [100]
learningRates = [.001]
labelDim = 262144

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

parentOut = f'{outputPath}/{name}/weightsNorm/'



In [4]:
for e in epochs:
    for lr in learningRates:
        for bs in batchSizes:
            model = unet.unet(labelDim=labelDim)
            model.to(device)
            


            trainLoader, valLoader = image_processor.imageDirsToLoaders(imageDir=imageDir, labelDir= labelDir, batchSize= bs)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            loss = torch.nn.MSELoss()

            print(f'EPOCHS: {e} | LR: {lr} | BATCH: {bs}')
            epochLosses = []
            validationLosses = []

            for i in range(e):
                model.train()
                currentTime = datetime.now()
                formatTime = currentTime.strftime("%Y-%m-%d %H:%M:%S")
                print(f'EPOCH: {i+1} | {formatTime}')
                runningLoss = 0
                epochLoss = 0
                for image, label in trainLoader:
                    image = image.to(device)
                    label = label.to(device)
                    optimizer.zero_grad()
                    outputClasses = model(image)
                    
                    #outputLogits = torch.sigmoid(outputClasses)
                    #labelLogits = torch.sigmoid(label)
                    batchLoss = loss(outputClasses, label)
                    
                    batchLoss.backward()
                    optimizer.step()
                    epochLoss += batchLoss

                currentTime = datetime.now()
                formatTime = currentTime.strftime("%Y-%m-%d %H:%M:%S")
                print(f"EPOCH: {i+1} | LOSS: {epochLoss} | {formatTime}")

                epochLosses.append(epochLoss)
                batchLoss = 0
                epochLoss = 0
            
                validationLoss = 0

                model.eval()
                with torch.no_grad():
                    for image, label in valLoader:
                        image = image.to(device)
                        label = label.to(device)
                        outputs = model(image)
                        valBatchLoss = loss(outputs, label)
                        validationLoss += valBatchLoss
                
                validationLoss = round(validationLoss.item()/len(valLoader))
                validationLosses.append(validationLoss)
                model.validationLoss = validationLosses
                model.traningLosses = epochLosses

                detailedName = name + "_VL_" + str(validationLoss) + "_E_" + str(e) + "_B_" + str(bs) + "_LR_" + str(lr)
                if not os.path.exists(parentOut):
                    os.makedirs(parentOut)
                writeOut = os.path.join(parentOut, f'{detailedName}.pt')
                modelWriter(model, writeOut)
                print(f'{detailedName} saved')

            
            




EPOCHS: 100 | LR: 0.001 | BATCH: 16
EPOCH: 1 | 2024-03-29 05:47:09
EPOCH: 1 | LOSS: 237020.203125 | 2024-03-29 06:06:48
unet_norm_VL_4771_E_100_B_16_LR_0.001 saved
EPOCH: 2 | 2024-03-29 06:09:50
EPOCH: 2 | LOSS: 204755.46875 | 2024-03-29 06:30:21
unet_norm_VL_4914_E_100_B_16_LR_0.001 saved
EPOCH: 3 | 2024-03-29 06:33:20
EPOCH: 3 | LOSS: 200905.265625 | 2024-03-29 06:52:50
unet_norm_VL_4713_E_100_B_16_LR_0.001 saved
EPOCH: 4 | 2024-03-29 06:55:25
EPOCH: 4 | LOSS: 199982.640625 | 2024-03-29 07:13:14
unet_norm_VL_4786_E_100_B_16_LR_0.001 saved
EPOCH: 5 | 2024-03-29 07:15:51
EPOCH: 5 | LOSS: 205843.3125 | 2024-03-29 07:33:40
unet_norm_VL_4586_E_100_B_16_LR_0.001 saved
EPOCH: 6 | 2024-03-29 07:36:17
EPOCH: 6 | LOSS: 16967018.0 | 2024-03-29 07:54:09
unet_norm_VL_5014_E_100_B_16_LR_0.001 saved
EPOCH: 7 | 2024-03-29 07:56:45
EPOCH: 7 | LOSS: 1871205.125 | 2024-03-29 08:14:37
unet_norm_VL_4858_E_100_B_16_LR_0.001 saved
EPOCH: 8 | 2024-03-29 08:17:13
EPOCH: 8 | LOSS: 341346.1875 | 2024-03-29 08: