In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torch.autograd import Variable

import numpy as np
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp

if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)


# Function to calculate Dice Similarity Coefficient
def dice_coefficient(outputs, targets):
    intersection = 2.0 * (outputs * targets).sum()
    union = (
        outputs.sum() + targets.sum() + 1e-7
    )  # Adding a small constant to avoid division by zero
    return intersection / union


# Function to train the model
def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")

        # Validation
        model.eval()
        total_dice = 0.0
        total_val_loss = 0.0

        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)

                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_targets)

                total_val_loss += val_loss.item()
                total_dice += dice_coefficient(
                    (val_outputs.argmax(dim=1) == 1).float(), (val_targets == 1).float()
                )

        average_val_loss = total_val_loss / len(val_loader)
        average_dice = total_dice / len(val_loader)
        print(f"Validation Loss: {average_val_loss}, Dice Coefficient: {average_dice}")


# Load your data
train_images = np.load("PancreasDataset/imagesTr/train_np/100_NPimages.npy")
mask_images = np.load("PancreasDataset/labelsTr/mask_np/100_NPimages.npy")

# Assuming you have a function to split your data into training and validation sets
train_images, val_images, mask_train, mask_val = train_test_split(
    train_images, mask_images, test_size=0.2, random_state=42
)

print("Train Shape", train_images.shape)

##
train_images = torch.from_numpy(train_images).float().unsqueeze(1)
mask_train = torch.from_numpy(mask_train).long()
val_images = torch.from_numpy(val_images).float().unsqueeze(1)
mask_val = torch.from_numpy(mask_val).long()

# Convert data to PyTorch tensors
train_dataset = TensorDataset(torch.Tensor(train_images), torch.Tensor(mask_train))
val_dataset = TensorDataset(torch.Tensor(val_images), torch.Tensor(mask_val))

print("Train Shape", torch.Tensor(train_images).size())

# DataLoader
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Create U-Net model using segmentation-models-pytorch
model = smp.Unet(
    "resnet34", in_channels=1, classes=3
)  # You can choose a different backbone
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train(model, criterion, optimizer, train_loader, val_loader, num_epochs=40)

Train Shape (4282, 512, 512)
Train Shape torch.Size([4282, 1, 512, 512])
Epoch 1/40, Loss: 0.05133759682874696
Validation Loss: 0.009142849605648653, Dice Coefficient: 0.0
Epoch 2/40, Loss: 0.008513812550192406
Validation Loss: 0.017757814965704315, Dice Coefficient: 0.0
Epoch 3/40, Loss: 0.006379370555262526
Validation Loss: 0.006004615883176118, Dice Coefficient: 0.005007224157452583
Epoch 4/40, Loss: 0.0049961575116854035
Validation Loss: 0.005558005728337083, Dice Coefficient: 0.5164913535118103
Epoch 5/40, Loss: 0.004045126603023842
Validation Loss: 0.005010631363517466, Dice Coefficient: 0.5260334610939026
Epoch 6/40, Loss: 0.003711262316307932
Validation Loss: 0.0032212100245148454, Dice Coefficient: 0.7097445130348206
Epoch 7/40, Loss: 0.0034395506164373864
Validation Loss: 0.0035239476343502164, Dice Coefficient: 0.6942983269691467
Epoch 8/40, Loss: 0.0030426182028988015
Validation Loss: 0.0032035783356218076, Dice Coefficient: 0.6621560454368591
Epoch 9/40, Loss: 0.0023496869

In [3]:
# Clear all variables from the workspace
%reset -f

# Import the garbage collector module
import gc

# Collect garbage to free up memory
gc.collect()

0