In [10]:
!pip install wandb --quiet # Install WandB
!pip install pytorch_metric_learning --quiet #Install the Pytorch Metric Library

In [11]:
import torch
from torchsummary import summary
import torchvision
from torchvision import transforms
from torchvision.transforms import v2
import torch.nn.functional as F
import os
import gc
from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers
import csv

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

Device:  cuda


In [12]:
from google.colab import drive # Link to your drive if you are not using Colab with GCP
drive.mount('/content/drive') # Models in this HW take a long time to get trained and make sure to save it here

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
"get data from google drive"
!mkdir /content/data

!mv "/content/drive/MyDrive/testdata.zip" /content/

!unzip -q testdata.zip -d /content/data

mkdir: cannot create directory ‘/content/data’: File exists
mv: cannot stat '/content/drive/MyDrive/testdata.zip': No such file or directory
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s058.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s128.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s222.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s251.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s271.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s407.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s512.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/data/testdata/EM/SEMdauer2_Mei_export_s775.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
replace /content/da

#Data proccessing


In [18]:

# Define the input directories for EM and label images and output target directory
em_folder = "/content/data/testdata/EM"
label_folder = "/content/data/testdata/Label"
target_dir = "/content/data/output_patches"
output_size = 256  # Define the size of each patch

# Ensure the output directories exist for both EM and label patches
em_target_dir = os.path.join(target_dir, "EM_patches")
label_target_dir = os.path.join(target_dir, "label_patches")
os.makedirs(em_target_dir, exist_ok=True)
os.makedirs(label_target_dir, exist_ok=True)

# Define the function to create overlapping patches
def break_image_overlap(image, patch_size):
    """Divide an image into overlapping patches of size (patch_size, patch_size)."""
    patches = []
    img_height, img_width = image.shape[0], image.shape[1]
    step = patch_size // 2  # 50% overlap

    for i in range(0, img_height - patch_size + 1, step):
        for j in range(0, img_width - patch_size + 1, step):
            patch = image[i:i + patch_size, j:j + patch_size]
            patches.append(patch)

    return np.stack(patches, axis=2)  # Stack patches along the third axis

# Helper function to find matching label file by suffix
def find_matching_label_file(em_filename):
    suffix = em_filename[-10:]  # Adjust the length based on the identifying pattern
    for label_filename in os.listdir(label_folder):
        if label_filename.endswith(suffix):
            return os.path.join(label_folder, label_filename)
    return None

# Main loop to process images from both folders
for em_filename in os.listdir(em_folder):
    # Construct the file path for the current EM image
    em_image_path = os.path.join(em_folder, em_filename)

    # Find the matching label image based on suffix
    label_image_path = find_matching_label_file(em_filename)

    # Process the pair if a matching label file is found
    if label_image_path:
        try:
            # Process EM Image
            em_image = Image.open(em_image_path).convert('L')  # Convert to grayscale
            em_image_np = np.array(em_image)
            em_patches = break_image_overlap(em_image_np, output_size)

            # Save EM patches
            for patch_idx in range(em_patches.shape[2]):
                patch = em_patches[:, :, patch_idx]
                patch_image = Image.fromarray(patch)
                em_patch_filename = os.path.join(em_target_dir, f"{os.path.splitext(em_filename)[0]}_em_{patch_idx + 1}.png")
                patch_image.save(em_patch_filename)

            # Process Label Image
            label_image = Image.open(label_image_path).convert('L')
            label_image_np = np.array(label_image)
            label_patches = break_image_overlap(label_image_np, output_size)

            # Save Label patches
            for patch_idx in range(label_patches.shape[2]):
                patch = label_patches[:, :, patch_idx]
                patch_image = Image.fromarray(patch)
                label_patch_filename = os.path.join(label_target_dir, f"{os.path.splitext(em_filename)[0]}_label_{patch_idx + 1}.png")
                patch_image.save(label_patch_filename)

        except Exception as e:
            print(f"Error processing {em_filename} and {label_image_path}: {e}")
    else:
        print(f"No matching label file found for {em_filename}. Skipping...")


print("Processing completed.")


Processing completed.


In [19]:
#check the shapes:
for filename in os.listdir('/content/data/output_patches/EM_patches'):
    if filename.endswith(".png"):
        image_path = os.path.join(em_target_dir, filename)
        image = Image.open(image_path)
        print(image.size)

(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)

KeyboardInterrupt: 

In [20]:
#seperate data into val and train data
import os
import shutil
import random

# Original data directories
em_dir = '/content/data/output_patches/EM_patches'
label_dir = '/content/data/output_patches/label_patches'

# New directories for training and validation sets
train_em_dir = '/content/data/train/images'
train_label_dir = '/content/data/train/labels'
val_em_dir = '/content/data/val/images'
val_label_dir = '/content/data/val/labels'

# Create directories if they don't exist
os.makedirs(train_em_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)
os.makedirs(val_em_dir, exist_ok=True)
os.makedirs(val_label_dir, exist_ok=True)

# Get list of filenames in the EM images directory
filenames = [f for f in os.listdir(em_dir) if os.path.isfile(os.path.join(em_dir, f))]

# Ensure the filenames are the same in both directories
label_filenames = [f for f in os.listdir(label_dir) if os.path.isfile(os.path.join(label_dir, f))]
assert set(filenames) == set(label_filenames), "Mismatch between EM images and labels filenames."

# Shuffle filenames
random.seed(42)  # For reproducibility
random.shuffle(filenames)

# Define split ratio
split_ratio = 0.8  # 80% training, 20% validation

# Calculate split index
split_index = int(len(filenames) * split_ratio)

# Split filenames
train_filenames = filenames[:split_index]
val_filenames = filenames[split_index:]

# Function to copy files
def copy_files(filenames, src_image_dir, src_label_dir, dest_image_dir, dest_label_dir):
    for filename in filenames:
        # Source file paths
        src_image_path = os.path.join(src_image_dir, filename)
        src_label_path = os.path.join(src_label_dir, filename)

        # Destination file paths
        dest_image_path = os.path.join(dest_image_dir, filename)
        dest_label_path = os.path.join(dest_label_dir, filename)

        # Copy image and label
        shutil.copy2(src_image_path, dest_image_path)
        shutil.copy2(src_label_path, dest_label_path)

# Copy training files
copy_files(train_filenames, em_dir, label_dir, train_em_dir, train_label_dir)

# Copy validation files
copy_files(val_filenames, em_dir, label_dir, val_em_dir, val_label_dir)

print("Dataset splitting completed successfully.")


AssertionError: Mismatch between EM images and labels filenames.

# Model

BaseLine model:

In [24]:
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class SEBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out

class SEResNet(nn.Module):
    def __init__(self, block, layers, num_classes=2):
        super(SEResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.final_conv = nn.Conv2d(512 * block.expansion, num_classes, kernel_size=1)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.final_conv(x)
        x = nn.functional.interpolate(x, scale_factor=32, mode='bilinear', align_corners=False)
        return x

def se_resnet50_segmentation(num_classes=2, pretrained=False):
    model = SEResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    return model


defining model

In [25]:
Network = se_resnet50_segmentation(num_classes=2)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Network.to(DEVICE)

summary(model, (3, 112, 112))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 56, 56]           9,408
       BatchNorm2d-2           [-1, 64, 56, 56]             128
              ReLU-3           [-1, 64, 56, 56]               0
         MaxPool2d-4           [-1, 64, 28, 28]               0
            Conv2d-5           [-1, 64, 28, 28]           4,096
       BatchNorm2d-6           [-1, 64, 28, 28]             128
              ReLU-7           [-1, 64, 28, 28]               0
            Conv2d-8           [-1, 64, 28, 28]          36,864
       BatchNorm2d-9           [-1, 64, 28, 28]             128
             ReLU-10           [-1, 64, 28, 28]               0
           Conv2d-11          [-1, 256, 28, 28]          16,384
      BatchNorm2d-12          [-1, 256, 28, 28]             512
AdaptiveAvgPool2d-13            [-1, 256, 1, 1]               0
           Linear-14                   

# Training

In [26]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets, models
import os
from PIL import Image
from tqdm import tqdm

# Define hyperparameters and configurations
image_size = (256, 256)
num_classes = 2
batch_size = 16  # Adjust as per system memory
learning_rate = 0.001
num_epochs = 20

# Directories for images and labels
image_dir = "/content/data/output_patches/EM_patches"
label_dir = "/content/data/output_patches/label_patches"

# Transformations for images and labels
data_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Custom Dataset for segmentation
class SegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))
        self.label_files = sorted(os.listdir(label_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.image_dir, self.image_files[idx])).convert("RGB")
        label = Image.open(os.path.join(self.label_dir, self.label_files[idx])).convert("L")

        if self.transform:
            image = self.transform(image)
            label = self.transform(label)

        return image, label.squeeze(0).long()  # Make label compatible with CrossEntropyLoss

# Create datasets and split into train, val, and test sets
dataset = SegmentationDataset(image_dir=image_dir, label_dir=label_dir, transform=data_transforms)
train_size = int(0.99 * len(dataset))
val_size = int(0.01 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define loss function with class weights (example weights, adjust based on dataset class distribution)
class_weights = torch.tensor([1.0, 2.0]).to(DEVICE)  # Example weights
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training and Validation Loop
for epoch in range(num_epochs):
    # Training phase
    model.train()
    loss_m = AverageMeter()
    running_loss = 0.0
    train_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        loss_m.update(loss.item())

        running_loss += loss.item() * images.size(0)
        train_bar.set_postfix(
            # acc         = "{:.04f}%".format(100*accuracy),
            loss        = "{:.04f} ({:.04f})".format(loss.item(), loss_m.avg),
            lr          = "{:.04f}".format(float(optimizer.param_groups[0]['lr'])))

        train_bar.update() # Update tqdm bar

    train_bar.close()
    epoch_train_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_train_loss:.4f}")

    # Validation phase
    model.eval()
    val_loss = 0.0
    loss_m = AverageMeter()
    val_bar = tqdm(total=len(val_loader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)
    with torch.no_grad():  # No gradient calculation for validation
        for i, (images, labels) in enumerate(val_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            loss_m.update(loss.item())
            val_bar.set_postfix(
              loss        = "{:.04f} ({:.04f})".format(loss.item(), loss_m.avg))

            val_bar.update()

    epoch_val_loss = val_loss / len(val_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {epoch_val_loss:.4f}")

# Saving the trained model
torch.save(model.state_dict(), "segmentation_model.pth")




Epoch [1/20], Training Loss: 0.0004


Train: 100%|██████████| 26/26 [00:01<00:00, 15.73it/s, loss=0.0000 (0.0000)]

Epoch [1/20], Validation Loss: 0.0000


Train:  31%|███▏      | 808/2567 [01:14<02:39, 10.99it/s, loss=0.0000 (0.0000), lr=0.0010]