In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, Adam
from utils.utils import *
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, Dataset
import pickle

In [2]:
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)

Using cache found in C:\Users\mateu/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


In [3]:
class UNet(nn.Module):
    def __init__(self, base_model):
        super(UNet, self).__init__()
        self.model = base_model
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=3)
        self.fc1 = nn.Linear(28900, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(256, 128)
        self.relu3 = nn.ReLU()
        self.output = nn.Linear(128, 3) 

    def forward(self, x):
        segmented = self.model(x)
        concatenated = torch.cat((x, segmented), dim=1)
        classified = self.maxpool(concatenated)
        classified = classified.view(classified.size(0), -1)
        classified = self.fc1(classified)
        classified = self.relu1(classified)
        classified = self.fc2(classified)
        classified = self.relu2(classified)
        classified = self.fc3(classified)
        classified = self.relu3(classified)
        output = self.output(classified)
        return segmented, output

# new cell

In [2]:
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


In [3]:

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))
        self.simoid = nn.Sigmoid()

        self.conv1 = nn.Conv2d(1024, 512, kernel_size=3)
        self.conv2 = nn.Conv2d(512, 256, kernel_size=3)
        self.conv3 = nn.Conv2d(256, 128, kernel_size=3)
        self.max_pool = nn.MaxPool2d(3)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(1152, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear3 = nn.Linear(64, 16)
        self.linear4 = nn.Linear(16, 3)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        logits = self.simoid(logits)

        x_class = self.conv1(x5)
        x_class = self.conv2(x_class)
        x_class = self.conv3(x_class)
        x_class = self.max_pool(x_class)
        x_class = self.flatten(x_class)
        x_class = torch.relu(self.linear1(x_class))
        x_class = torch.relu(self.linear2(x_class))
        x_class = torch.relu(self.linear3(x_class))
        x_class = self.linear4(x_class)
        return logits, x_class

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

In [6]:
unet = UNet(3, 1)
print(unet)

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Example usage:
total_params = count_parameters(unet)
print("Total trainable parameters: ", total_params)

Total trainable parameters:  37544388


In [9]:
NUM_EPOCHS = 1
BATCH_SIZE = 16

In [10]:
benign_images, benign_masks = read_data(directory = "data//benign", target_size=(256, 256))
normal_images, normal_masks = read_data(directory = "data//normal", target_size=(256, 256))
malignant_images, malignant_masks = read_data(directory = "data//malignant", target_size=(256, 256))

benign_masks = np.reshape(benign_masks, (benign_masks.shape[0], 1, benign_masks.shape[1], benign_masks.shape[2]))
normal_masks = np.reshape(normal_masks, (normal_masks.shape[0], 1, normal_masks.shape[1], normal_masks.shape[2]))
malignant_masks = np.reshape(malignant_masks, (malignant_masks.shape[0], 1, malignant_masks.shape[1], malignant_masks.shape[2]))

benign_images = np.transpose(benign_images, (0, 3, 1, 2))
normal_images = np.transpose(normal_images, (0, 3, 1, 2))
malignant_images = np.transpose(malignant_images, (0, 3, 1, 2))

benign_values = np.array([1, 0, 0], dtype = np.float32)
normal_values = np.array([0, 1, 0], dtype = np.float32)
malignant_values = np.array([0, 0, 1], dtype = np.float32)
num_benign = len(benign_images)
num_normal = len(normal_images)
num_malignant = len(malignant_images)
y_classification = []
y_segmentation = []
X = []

for i in range(num_benign):
    y_classification.append(benign_values)
for i in range(num_normal):
    y_classification.append(normal_values)
for i in range(num_malignant):
    y_classification.append(malignant_values)
y_classification = np.array(y_classification, dtype = np.float32)

benign_masks[benign_masks>0] = 1
normal_masks[normal_masks>0] = 1
malignant_masks[malignant_masks>0] = 1
benign_images = benign_images.astype(dtype = np.float32)
normal_images = normal_images.astype(dtype = np.float32)
malignant_images = malignant_images.astype(dtype = np.float32)
benign_images /= 255
normal_images /= 255
malignant_images /= 255

for img in benign_images:
    X.append(img)
for img in normal_images:
    X.append(img)
for img in malignant_images:
    X.append(img)
X = np.array(X, dtype=np.float32)

for img in benign_masks:
    y_segmentation.append(img)
for img in normal_masks:
    y_segmentation.append(img)
for img in malignant_masks:
    y_segmentation.append(img)
y_segmentation = np.array(y_segmentation, dtype=np.float32)

num_samples = len(y_segmentation)
perm = np.random.permutation(num_samples)

X = X[perm]
y_classification = y_classification[perm]
y_segmentation = y_segmentation[perm]

total_samples = len(X)
train_size = int(0.7 * total_samples)
val_size = int(0.2 * total_samples)

X_train = X[:train_size]
y_classification_train = y_classification[:train_size]
y_segmentation_train = y_segmentation[:train_size]

X_val = X[train_size:train_size + val_size]
y_classification_val = y_classification[train_size:train_size + val_size]
y_segmentation_val = y_segmentation[train_size:train_size + val_size]

X_test = X[train_size + val_size:]
y_classification_test = y_classification[train_size + val_size:]
y_segmentation_test = y_segmentation[train_size + val_size:]

In [11]:
class CustomDataset(Dataset):
    def __init__(self, X, y1, y2):
        self.X = X
        self.y1 = y1
        self.y2 = y2

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y1[idx], self.y2[idx]

In [12]:
X_train = torch.Tensor(X_train)
y_classification_train = torch.Tensor(y_classification_train)
y_segmentation_train = torch.Tensor(y_segmentation_train)

X_val = torch.Tensor(X_val)
y_classification_val = torch.Tensor(y_classification_val)
y_segmentation_val = torch.Tensor(y_segmentation_val)

X_test = torch.Tensor(X_test)
y_classification_test = torch.Tensor(y_classification_test)
y_segmentation_test = torch.Tensor(y_segmentation_test)

train_dataset = CustomDataset(X_train, y_classification_train, y_segmentation_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_dataset = CustomDataset(X_val, y_classification_val, y_segmentation_val)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataset = CustomDataset(X_test, y_classification_test, y_segmentation_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
y_segmentation_train.shape
y_classification.shape

(780, 3)

In [15]:
unet = UNet(3, 1)
unet.train()

optimizer = Adam(unet.parameters(), lr = 1e-4)
segmentation_loss = torch.nn.BCEWithLogitsLoss()
classification_loss = nn.CrossEntropyLoss()

classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []
accuracy_list = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []
    val_loss_class = []
    val_loss_seg = []
    test_loss_class = []
    test_loss_seg = []
    accuracy = []

    for batch in train_loader:
        X_data, y_class, y_seg = batch
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        segmentation_output = torch.sigmoid(segmentation_output)
        classification_output = torch.softmax(classification_output, dim=1)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())
        predicted_classes = torch.argmax(classification_output, dim=1)
        temp_class = torch.argmax(y_class, dim=1)
        accuracy.append((predicted_classes == temp_class).float().mean().item())

    accuracy = sum(accuracy)/len(accuracy)
    accuracy_list.append(accuracy)
    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        val_loss_class = []
        val_loss_seg = []
        for batch in val_loader:
            X_val, y_class_val, y_seg_val = batch
            segmentation_output_val, classification_output_val = unet(X_val)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_val = segmentation_loss(segmentation_output_val, y_seg_val)
            class_loss_val = classification_loss(classification_output_val, y_class_val)
            val_loss_class.append(class_loss_val.item())
            val_loss_seg.append(seg_loss_val.item())
        
        classification_loss_val.append(sum(val_loss_class) / len(val_loss_class))
        segmentation_loss_val.append(sum(val_loss_seg) / len(val_loss_seg))
    
    with torch.no_grad():
        test_loss_class = []
        test_loss_seg = []
        for batch in test_loader:
            X_test, y_class_test, y_seg_test = batch
            segmentation_output_test, classification_output_test = unet(X_test)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_test = segmentation_loss(segmentation_output_test, y_seg_test)
            class_loss_test = classification_loss(classification_output_test, y_class_test)
            test_loss_class.append(class_loss_test.item())
            test_loss_seg.append(seg_loss_test.item())
        
        classification_loss_test.append(sum(test_loss_class) / len(test_loss_class))
        segmentation_loss_test.append(sum(test_loss_seg) / len(test_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Train accuracy: {accuracy_list[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


Epoch: 0, Classification Training Loss: 1.0129894052233015, Train accuracy: 0.4928571428571429, Segmentation Training Loss: 0.985618497644152, Classification Validation Loss: 2.6705032706260683, Segmentation Validation Loss: 0.8495967745780945, Classification Test Loss: 5.451157522201538, Segmentation Test Loss: 0.838263475894928


In [None]:
torch.save(unet.state_dict(), 'unet1.pth')

In [None]:
lists_to_save = {
    'classification_loss_train': classification_loss_train,
    'segmentation_loss_train': segmentation_loss_train,
    'classification_loss_val': classification_loss_val,
    'segmentation_loss_val': segmentation_loss_val,
    'classification_loss_test': classification_loss_test,
    'segmentation_loss_test': segmentation_loss_test
}

with open('unet_results_1.pkl', 'wb') as f:
    pickle.dump(lists_to_save, f)

In [None]:
unet = UNet(model)
unet.train()

optimizer = Adam(unet.parameters(), lr = 1e-4)
segmentation_loss = torch.nn.MSELoss()
classification_loss = nn.CrossEntropyLoss()

classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []
accuracy_list = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []
    val_loss_class = []
    val_loss_seg = []
    test_loss_class = []
    test_loss_seg = []
    accuracy = []

    for batch in train_loader:
        X_data, y_class, y_seg = batch
        optimizer.zero_grad()
        segmentation_output, _ = unet(X_data)
        segmentation_output = torch.sigmoid(segmentation_output)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        seg_loss.backward()
        optimizer.step()

        optimizer.zero_grad()
        _,classification_output  = unet(X_data)
        classification_output = torch.softmax(classification_output, dim=1)
        class_loss = classification_loss(classification_output, y_class)
        class_loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())
        temp_class = torch.argmax(y_class, dim=1)
        accuracy.append((predicted_classes == temp_class).float().mean().item())

    accuracy = sum(accuracy)/len(accuracy)
    accuracy_list.append(accuracy)
    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        val_loss_class = []
        val_loss_seg = []
        for batch in val_loader:
            X_val, y_class_val, y_seg_val = batch
            segmentation_output_val, classification_output_val = unet(X_val)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_val = segmentation_loss(segmentation_output_val, y_seg_val)
            class_loss_val = classification_loss(classification_output_val, y_class_val)
            val_loss_class.append(class_loss_val.item())
            val_loss_seg.append(seg_loss_val.item())
        
        classification_loss_val.append(sum(val_loss_class) / len(val_loss_class))
        segmentation_loss_val.append(sum(val_loss_seg) / len(val_loss_seg))
    
    with torch.no_grad():
        test_loss_class = []
        test_loss_seg = []
        for batch in test_loader:
            X_test, y_class_test, y_seg_test = batch
            segmentation_output_test, classification_output_test = unet(X_test)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_test = segmentation_loss(segmentation_output_test, y_seg_test)
            class_loss_test = classification_loss(classification_output_test, y_class_test)
            test_loss_class.append(class_loss_test.item())
            test_loss_seg.append(seg_loss_test.item())
        
        classification_loss_test.append(sum(test_loss_class) / len(test_loss_class))
        segmentation_loss_test.append(sum(test_loss_seg) / len(test_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Train accuracy: {accuracy_list[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


In [None]:
torch.save(unet.state_dict(), 'unet2.pth')

In [None]:
lists_to_save = {
    'classification_loss_train': classification_loss_train,
    'segmentation_loss_train': segmentation_loss_train,
    'classification_loss_val': classification_loss_val,
    'segmentation_loss_val': segmentation_loss_val,
    'classification_loss_test': classification_loss_test,
    'segmentation_loss_test': segmentation_loss_test
}

with open('unet_results_2.pkl', 'wb') as f:
    pickle.dump(lists_to_save, f)

In [None]:
b1 = 0.8
b2 = 0.85
unet = UNet(model)
unet.train()

optimizer = Adam(unet.parameters(), lr = 1e-3, betas=(b1, b2), weight_decay=0.0005)
segmentation_loss = torch.nn.MSELoss()
classification_loss = nn.CrossEntropyLoss()

classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []
accuracy_list = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []
    val_loss_class = []
    val_loss_seg = []
    test_loss_class = []
    test_loss_seg = []
    accuracy = []

    for batch in train_loader:
        X_data, y_class, y_seg = batch
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        segmentation_output = torch.sigmoid(segmentation_output)
        classification_output = torch.softmax(classification_output, dim=1)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())
        temp_class = torch.argmax(y_class, dim=1)
        accuracy.append((predicted_classes == temp_class).float().mean().item())

    accuracy = sum(accuracy)/len(accuracy)
    accuracy_list.append(accuracy)
    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        val_loss_class = []
        val_loss_seg = []
        for batch in val_loader:
            X_val, y_class_val, y_seg_val = batch
            segmentation_output_val, classification_output_val = unet(X_val)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_val = segmentation_loss(segmentation_output_val, y_seg_val)
            class_loss_val = classification_loss(classification_output_val, y_class_val)
            val_loss_class.append(class_loss_val.item())
            val_loss_seg.append(seg_loss_val.item())
        
        classification_loss_val.append(sum(val_loss_class) / len(val_loss_class))
        segmentation_loss_val.append(sum(val_loss_seg) / len(val_loss_seg))
    
    with torch.no_grad():
        test_loss_class = []
        test_loss_seg = []
        for batch in test_loader:
            X_test, y_class_test, y_seg_test = batch
            segmentation_output_test, classification_output_test = unet(X_test)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_test = segmentation_loss(segmentation_output_test, y_seg_test)
            class_loss_test = classification_loss(classification_output_test, y_class_test)
            test_loss_class.append(class_loss_test.item())
            test_loss_seg.append(seg_loss_test.item())
        
        classification_loss_test.append(sum(test_loss_class) / len(test_loss_class))
        segmentation_loss_test.append(sum(test_loss_seg) / len(test_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Train accuracy: {accuracy_list[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


In [None]:
torch.save(unet.state_dict(), 'unet3.pth')

In [None]:
lists_to_save = {
    'classification_loss_train': classification_loss_train,
    'segmentation_loss_train': segmentation_loss_train,
    'classification_loss_val': classification_loss_val,
    'segmentation_loss_val': segmentation_loss_val,
    'classification_loss_test': classification_loss_test,
    'segmentation_loss_test': segmentation_loss_test
}

with open('unet_results_3.pkl', 'wb') as f:
    pickle.dump(lists_to_save, f)

In [None]:
unet = UNet(model)
unet.train()

optimizer = AdamW(unet.parameters(), lr = 1e-4)
segmentation_loss = torch.nn.MSELoss()
classification_loss = nn.CrossEntropyLoss()

classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []
accuracy_list = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []
    val_loss_class = []
    val_loss_seg = []
    test_loss_class = []
    test_loss_seg = []
    accuracy = []

    for batch in train_loader:
        X_data, y_class, y_seg = batch
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        segmentation_output = torch.sigmoid(segmentation_output)
        classification_output = torch.softmax(classification_output, dim=1)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())
        temp_class = torch.argmax(y_class, dim=1)
        accuracy.append((predicted_classes == temp_class).float().mean().item())

    accuracy = sum(accuracy)/len(accuracy)
    accuracy_list.append(accuracy)
    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        val_loss_class = []
        val_loss_seg = []
        for batch in val_loader:
            X_val, y_class_val, y_seg_val = batch
            segmentation_output_val, classification_output_val = unet(X_val)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_val = segmentation_loss(segmentation_output_val, y_seg_val)
            class_loss_val = classification_loss(classification_output_val, y_class_val)
            val_loss_class.append(class_loss_val.item())
            val_loss_seg.append(seg_loss_val.item())
        
        classification_loss_val.append(sum(val_loss_class) / len(val_loss_class))
        segmentation_loss_val.append(sum(val_loss_seg) / len(val_loss_seg))
    
    with torch.no_grad():
        test_loss_class = []
        test_loss_seg = []
        for batch in test_loader:
            X_test, y_class_test, y_seg_test = batch
            segmentation_output_test, classification_output_test = unet(X_test)
            segmentation_output = torch.sigmoid(segmentation_output)
            classification_output = torch.softmax(classification_output, dim=1)
            seg_loss_test = segmentation_loss(segmentation_output_test, y_seg_test)
            class_loss_test = classification_loss(classification_output_test, y_class_test)
            test_loss_class.append(class_loss_test.item())
            test_loss_seg.append(seg_loss_test.item())
        
        classification_loss_test.append(sum(test_loss_class) / len(test_loss_class))
        segmentation_loss_test.append(sum(test_loss_seg) / len(test_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Train accuracy: {accuracy_list[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


In [None]:
torch.save(unet.state_dict(), 'unet4.pth')

In [None]:
lists_to_save = {
    'classification_loss_train': classification_loss_train,
    'segmentation_loss_train': segmentation_loss_train,
    'classification_loss_val': classification_loss_val,
    'segmentation_loss_val': segmentation_loss_val,
    'classification_loss_test': classification_loss_test,
    'segmentation_loss_test': segmentation_loss_test
}

with open('unet_results_4.pkl', 'wb') as f:
    pickle.dump(lists_to_save, f)