In [1]:
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.optim import AdamW, Adam
from torch.utils.data import TensorDataset, DataLoader, Dataset

from utils.utils import *

In [2]:
# Pretrained U-Net fetched through torch.hub (3 input channels, 1 output
# channel); it serves as the segmentation backbone for the wrapper model.
unet_hub_kwargs = dict(in_channels=3, out_channels=1, init_features=32, pretrained=True)
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', **unet_hub_kwargs)

Using cache found in C:\Users\mateu/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


In [3]:
class UNet(nn.Module):
    """Segmentation backbone plus a fully connected 3-way classification head.

    The head looks at the input image stacked channel-wise with the predicted
    mask, down-sampled by a 3x3 / stride-3 max-pool and flattened.

    The first linear layer is hard-wired to 28900 input features, which is
    4 * 85 * 85: a 3-channel image plus a 1-channel mask at 256x256 pooled
    down to 85x85. Other input sizes or channel counts will not fit ``fc1``.
    """

    def __init__(self, base_model):
        """Store ``base_model`` as the segmentation network and build the head.

        Args:
            base_model: module mapping an image batch to a mask batch with the
                same spatial size (it is concatenated with the input on dim 1).
        """
        super().__init__()
        self.model = base_model
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=3)
        # Layers are created in a fixed order so seeded initialisation and
        # state_dict keys stay stable.
        self.fc1 = nn.Linear(28900, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(256, 128)
        self.relu3 = nn.ReLU()
        self.output = nn.Linear(128, 3)

    def forward(self, x):
        """Run segmentation, then classify from image + mask.

        Returns:
            tuple: ``(segmented, output)`` where ``segmented`` is whatever the
            backbone returns for ``x`` and ``output`` holds the raw (un-normalised)
            scores of the three classes, one row per sample.
        """
        segmented = self.model(x)
        features = self.maxpool(torch.cat((x, segmented), dim=1))
        features = torch.flatten(features, start_dim=1)
        hidden_stack = (self.fc1, self.relu1, self.fc2, self.relu2, self.fc3, self.relu3)
        for layer in hidden_stack:
            features = layer(features)
        return segmented, self.output(features)

In [4]:
# Wrap the pretrained backbone with the classification head and list its layers.
unet = UNet(base_model=model)
print(unet)

UNet(
  (model): UNet(
    (encoder1): Sequential(
      (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (enc1relu1): ReLU(inplace=True)
      (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (enc1relu2): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (encoder2): Sequential(
      (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (enc2relu1): ReLU(inplace=True)
      (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (enc2norm2): BatchNorm2

In [5]:
def count_parameters(model):
    """Return the number of scalar weights in ``model`` that require gradients.

    Parameters with ``requires_grad=False`` (frozen weights) are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report how many weights of the combined model are trainable.
total_params = count_parameters(model=unet)
print("Total trainable parameters: ", total_params)

Total trainable parameters:  22724964


In [6]:
# Training hyperparameters: number of passes over the training set and the
# mini-batch size used by the DataLoaders.
NUM_EPOCHS, BATCH_SIZE = 100, 16

In [7]:
# Build the dataset: load each class, convert it to the layout the network
# consumes, attach one-hot labels, shuffle once and split 70/20/10 into
# train/validation/test.

# Fixed seed so the split membership is identical after every kernel restart;
# an unseeded permutation makes runs impossible to compare.
SPLIT_SEED = 42


def _prepare_class(images, masks):
    """Convert one class's raw arrays to network layout.

    Images are transposed from channels-last to channels-first, cast to
    float32 and divided by 255. Masks are reshaped to carry a singleton
    channel axis after the batch axis and binarised (values > 0 become 1.0,
    everything else 0.0) as float32.

    NOTE(review): assumes read_data returns images shaped (N, H, W, C) and
    masks shaped (N, H, W); confirm against utils.utils.
    """
    images = np.transpose(images, (0, 3, 1, 2)).astype(np.float32)
    images /= 255
    masks = np.reshape(masks, (masks.shape[0], 1, masks.shape[1], masks.shape[2]))
    masks = (masks > 0).astype(np.float32)
    return images, masks


benign_images, benign_masks = read_data(directory = "data//benign", target_size=(256, 256))
normal_images, normal_masks = read_data(directory = "data//normal", target_size=(256, 256))
malignant_images, malignant_masks = read_data(directory = "data//malignant", target_size=(256, 256))

benign_images, benign_masks = _prepare_class(benign_images, benign_masks)
normal_images, normal_masks = _prepare_class(normal_images, normal_masks)
malignant_images, malignant_masks = _prepare_class(malignant_images, malignant_masks)

# One-hot class targets, in the order benign / normal / malignant.
benign_values = np.array([1, 0, 0], dtype = np.float32)
normal_values = np.array([0, 1, 0], dtype = np.float32)
malignant_values = np.array([0, 0, 1], dtype = np.float32)
num_benign = len(benign_images)
num_normal = len(normal_images)
num_malignant = len(malignant_images)

# Stack the classes in the same order everywhere so row i of X, of
# y_classification and of y_segmentation always describe the same sample.
X = np.concatenate([benign_images, normal_images, malignant_images], axis=0)
y_segmentation = np.concatenate([benign_masks, normal_masks, malignant_masks], axis=0)
y_classification = np.repeat(
    np.stack([benign_values, normal_values, malignant_values]),
    [num_benign, num_normal, num_malignant],
    axis=0,
)
assert len(X) == len(y_classification) == len(y_segmentation), \
    "images, labels and masks must have one row per sample"

# Shuffle once with a single permutation applied to all three arrays.
num_samples = len(y_segmentation)
perm = np.random.default_rng(SPLIT_SEED).permutation(num_samples)

X = X[perm]
y_classification = y_classification[perm]
y_segmentation = y_segmentation[perm]

# 70% train, 20% validation, remainder (about 10%) test.
total_samples = len(X)
train_size = int(0.7 * total_samples)
val_size = int(0.2 * total_samples)
val_end = train_size + val_size

X_train = X[:train_size]
y_classification_train = y_classification[:train_size]
y_segmentation_train = y_segmentation[:train_size]

X_val = X[train_size:val_end]
y_classification_val = y_classification[train_size:val_end]
y_segmentation_val = y_segmentation[train_size:val_end]

X_test = X[val_end:]
y_classification_test = y_classification[val_end:]
y_segmentation_test = y_segmentation[val_end:]

In [8]:
class CustomDataset(Dataset):
    """Dataset pairing each input with two targets.

    Indexing returns ``(X[idx], y1[idx], y2[idx])``; the length is taken from
    ``X`` alone, so the three containers are expected to be equally long.
    """

    def __init__(self, X, y1, y2):
        """Keep references to the inputs ``X`` and the two target containers."""
        self.X, self.y1, self.y2 = X, y1, y2

    def __len__(self):
        """Number of samples, defined by the length of ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``idx``-th input together with both of its targets."""
        return tuple(field[idx] for field in (self.X, self.y1, self.y2))

In [9]:
# Convert the NumPy splits to float32 tensors and wrap them in DataLoaders.
# torch.as_tensor reuses an existing float32 buffer instead of copying it and
# is a no-op on tensors, so re-running this cell is harmless.
X_train = torch.as_tensor(X_train, dtype=torch.float32)
y_classification_train = torch.as_tensor(y_classification_train, dtype=torch.float32)
y_segmentation_train = torch.as_tensor(y_segmentation_train, dtype=torch.float32)

X_val = torch.as_tensor(X_val, dtype=torch.float32)
y_classification_val = torch.as_tensor(y_classification_val, dtype=torch.float32)
y_segmentation_val = torch.as_tensor(y_segmentation_val, dtype=torch.float32)

X_test = torch.as_tensor(X_test, dtype=torch.float32)
y_classification_test = torch.as_tensor(y_classification_test, dtype=torch.float32)
y_segmentation_test = torch.as_tensor(y_segmentation_test, dtype=torch.float32)

# Each dataset yields (image, class target, mask target) in that order.
train_dataset = CustomDataset(X_train, y_classification_train, y_segmentation_train)
# Training batches are reshuffled every epoch; with shuffle=False the model
# would see identical batches in identical order for all epochs. The seeded
# generator keeps the shuffling order reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Evaluation loaders keep a fixed order; shuffling does not affect mean losses.
val_dataset = CustomDataset(X_val, y_classification_val, y_segmentation_val)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataset = CustomDataset(X_test, y_classification_test, y_segmentation_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
# Shape sanity check. A cell only echoes its last expression, so the output
# shown is y_classification.shape; the value of the first line is discarded.
y_segmentation_train.shape
y_classification.shape

(780, 3)

In [11]:
# Experiment: Adam, lr=1e-4. Jointly fine-tunes segmentation and
# classification and records the mean loss per epoch on every split.

# Deep-copy the pretrained backbone: UNet(model) would wrap the one shared
# `model` object, so this run would continue from whatever weights an earlier
# training cell left behind instead of from the pretrained ones.
unet = UNet(copy.deepcopy(model))
unet.train()

optimizer = Adam(unet.parameters(), lr = 1e-4)
# The mask has a single channel, i.e. a per-pixel binary target. A softmax
# cross-entropy over one channel is identically 0 and yields no gradient, so
# binary cross-entropy is used. BCELoss expects probabilities; the torch.hub
# U-Net is documented to output a probability map, so its output is used as
# is, without another sigmoid.
# NOTE(review): if the backbone is swapped for one returning logits, use
# nn.BCEWithLogitsLoss instead.
segmentation_loss = nn.BCELoss()
# CrossEntropyLoss applies log-softmax internally, so it receives the raw
# class scores; the targets are the one-hot float rows.
classification_loss = nn.CrossEntropyLoss()

# Per-epoch mean losses for each split.
classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []

    for X_data, y_class, y_seg in train_loader:
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        # Both tasks are optimised together with equal weight.
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())

    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        # Validation and test are scored identically, each on its own outputs.
        for eval_loader, class_history, seg_history in (
            (val_loader, classification_loss_val, segmentation_loss_val),
            (test_loader, classification_loss_test, segmentation_loss_test),
        ):
            eval_loss_class = []
            eval_loss_seg = []
            # Batch names differ from X_val / X_test on purpose so the
            # dataset tensors of those names are not overwritten.
            for X_batch, y_class_batch, y_seg_batch in eval_loader:
                seg_pred, class_logits = unet(X_batch)
                eval_loss_seg.append(segmentation_loss(seg_pred, y_seg_batch).item())
                eval_loss_class.append(classification_loss(class_logits, y_class_batch).item())

            class_history.append(sum(eval_loss_class) / len(eval_loss_class))
            seg_history.append(sum(eval_loss_seg) / len(eval_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


IndexError: list index out of range

In [None]:
# Experiment: Adam, lr=1e-4 (same optimizer settings as the baseline run).
# Jointly fine-tunes segmentation and classification and records the mean
# loss per epoch on every split.

# Deep-copy the pretrained backbone: UNet(model) would wrap the one shared
# `model` object, so this run would continue from whatever weights an earlier
# training cell left behind instead of from the pretrained ones.
unet = UNet(copy.deepcopy(model))
unet.train()

optimizer = Adam(unet.parameters(), lr = 1e-4)
# The mask has a single channel, i.e. a per-pixel binary target. A softmax
# cross-entropy over one channel is identically 0 and yields no gradient, so
# binary cross-entropy is used. BCELoss expects probabilities; the torch.hub
# U-Net is documented to output a probability map, so its output is used as
# is, without another sigmoid.
# NOTE(review): if the backbone is swapped for one returning logits, use
# nn.BCEWithLogitsLoss instead.
segmentation_loss = nn.BCELoss()
# CrossEntropyLoss applies log-softmax internally, so it receives the raw
# class scores; the targets are the one-hot float rows.
classification_loss = nn.CrossEntropyLoss()

# Per-epoch mean losses for each split.
classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []

    for X_data, y_class, y_seg in train_loader:
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        # Both tasks are optimised together with equal weight.
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())

    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        # Validation and test are scored identically, each on its own outputs.
        for eval_loader, class_history, seg_history in (
            (val_loader, classification_loss_val, segmentation_loss_val),
            (test_loader, classification_loss_test, segmentation_loss_test),
        ):
            eval_loss_class = []
            eval_loss_seg = []
            # Batch names differ from X_val / X_test on purpose so the
            # dataset tensors of those names are not overwritten.
            for X_batch, y_class_batch, y_seg_batch in eval_loader:
                seg_pred, class_logits = unet(X_batch)
                eval_loss_seg.append(segmentation_loss(seg_pred, y_seg_batch).item())
                eval_loss_class.append(classification_loss(class_logits, y_class_batch).item())

            class_history.append(sum(eval_loss_class) / len(eval_loss_class))
            seg_history.append(sum(eval_loss_seg) / len(eval_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


In [None]:
# Experiment: Adam with lr=1e-3, custom betas and L2 weight decay. Jointly
# fine-tunes segmentation and classification and records the mean loss per
# epoch on every split.
b1 = 0.8   # decay rate of the running mean of gradients
b2 = 0.85  # decay rate of the running mean of squared gradients

# Deep-copy the pretrained backbone: UNet(model) would wrap the one shared
# `model` object, so this run would continue from whatever weights an earlier
# training cell left behind instead of from the pretrained ones.
unet = UNet(copy.deepcopy(model))
unet.train()

optimizer = Adam(unet.parameters(), lr = 1e-3, betas=(b1, b2), weight_decay=0.0005)
# The mask has a single channel, i.e. a per-pixel binary target. A softmax
# cross-entropy over one channel is identically 0 and yields no gradient, so
# binary cross-entropy is used. BCELoss expects probabilities; the torch.hub
# U-Net is documented to output a probability map, so its output is used as
# is, without another sigmoid.
# NOTE(review): if the backbone is swapped for one returning logits, use
# nn.BCEWithLogitsLoss instead.
segmentation_loss = nn.BCELoss()
# CrossEntropyLoss applies log-softmax internally, so it receives the raw
# class scores; the targets are the one-hot float rows.
classification_loss = nn.CrossEntropyLoss()

# Per-epoch mean losses for each split.
classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []

    for X_data, y_class, y_seg in train_loader:
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        # Both tasks are optimised together with equal weight.
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())

    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        # Validation and test are scored identically, each on its own outputs.
        for eval_loader, class_history, seg_history in (
            (val_loader, classification_loss_val, segmentation_loss_val),
            (test_loader, classification_loss_test, segmentation_loss_test),
        ):
            eval_loss_class = []
            eval_loss_seg = []
            # Batch names differ from X_val / X_test on purpose so the
            # dataset tensors of those names are not overwritten.
            for X_batch, y_class_batch, y_seg_batch in eval_loader:
                seg_pred, class_logits = unet(X_batch)
                eval_loss_seg.append(segmentation_loss(seg_pred, y_seg_batch).item())
                eval_loss_class.append(classification_loss(class_logits, y_class_batch).item())

            class_history.append(sum(eval_loss_class) / len(eval_loss_class))
            seg_history.append(sum(eval_loss_seg) / len(eval_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")


In [None]:
# Experiment: AdamW, lr=1e-4. Jointly fine-tunes segmentation and
# classification and records the mean loss per epoch on every split.

# Deep-copy the pretrained backbone: UNet(model) would wrap the one shared
# `model` object, so this run would continue from whatever weights an earlier
# training cell left behind instead of from the pretrained ones.
unet = UNet(copy.deepcopy(model))
unet.train()

optimizer = AdamW(unet.parameters(), lr = 1e-4)
# The mask has a single channel, i.e. a per-pixel binary target. A softmax
# cross-entropy over one channel is identically 0 and yields no gradient, so
# binary cross-entropy is used. BCELoss expects probabilities; the torch.hub
# U-Net is documented to output a probability map, so its output is used as
# is, without another sigmoid.
# NOTE(review): if the backbone is swapped for one returning logits, use
# nn.BCEWithLogitsLoss instead.
segmentation_loss = nn.BCELoss()
# CrossEntropyLoss applies log-softmax internally, so it receives the raw
# class scores; the targets are the one-hot float rows.
classification_loss = nn.CrossEntropyLoss()

# Per-epoch mean losses for each split.
classification_loss_train = []
segmentation_loss_train = []
classification_loss_val = []
segmentation_loss_val = []
classification_loss_test = []
segmentation_loss_test = []

for i in range(NUM_EPOCHS):
    unet.train()
    train_loss_class = []
    train_loss_seg = []

    for X_data, y_class, y_seg in train_loader:
        optimizer.zero_grad()
        segmentation_output, classification_output = unet(X_data)
        seg_loss = segmentation_loss(segmentation_output, y_seg)
        class_loss = classification_loss(classification_output, y_class)
        # Both tasks are optimised together with equal weight.
        loss = seg_loss + class_loss
        loss.backward()
        optimizer.step()
        train_loss_class.append(class_loss.item())
        train_loss_seg.append(seg_loss.item())

    classification_loss_train.append(sum(train_loss_class) / len(train_loss_class))
    segmentation_loss_train.append(sum(train_loss_seg) / len(train_loss_seg))

    unet.eval()
    with torch.no_grad():
        # Validation and test are scored identically, each on its own outputs.
        for eval_loader, class_history, seg_history in (
            (val_loader, classification_loss_val, segmentation_loss_val),
            (test_loader, classification_loss_test, segmentation_loss_test),
        ):
            eval_loss_class = []
            eval_loss_seg = []
            # Batch names differ from X_val / X_test on purpose so the
            # dataset tensors of those names are not overwritten.
            for X_batch, y_class_batch, y_seg_batch in eval_loader:
                seg_pred, class_logits = unet(X_batch)
                eval_loss_seg.append(segmentation_loss(seg_pred, y_seg_batch).item())
                eval_loss_class.append(classification_loss(class_logits, y_class_batch).item())

            class_history.append(sum(eval_loss_class) / len(eval_loss_class))
            seg_history.append(sum(eval_loss_seg) / len(eval_loss_seg))

    print(f"Epoch: {i}, Classification Training Loss: {classification_loss_train[-1]}, Segmentation Training Loss: {segmentation_loss_train[-1]}, Classification Validation Loss: {classification_loss_val[-1]}, Segmentation Validation Loss: {segmentation_loss_val[-1]}, Classification Test Loss: {classification_loss_test[-1]}, Segmentation Test Loss: {segmentation_loss_test[-1]}")
