In [None]:
# Mount Drive to access BRISC dataset files
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install torch torchvision numpy opencv-python tqdm

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Dataset class for BRISC with .jpg images and .png masks
class BRISCDataset(Dataset):
    """Paired grayscale image / binary mask dataset for BRISC segmentation.

    Every ``.jpg`` file found in ``img_dir`` is paired with the ``.png`` file
    of the same stem in ``mask_dir``. Each sample is resized to 256x256 and
    returned as two float32 tensors of shape ``(1, 256, 256)``: the image
    scaled to ``[0, 1]`` and the mask holding only the values 0.0 and 1.0.

    Args:
        img_dir: Directory containing the ``.jpg`` images.
        mask_dir: Directory containing the ``.png`` masks.
    """

    def __init__(self, img_dir, mask_dir):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Sorted so that the index -> file mapping is stable between runs.
        self.images = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``img_dir``."""
        return len(self.images)

    def _mask_path(self, img_name):
        """Return the mask path for ``img_name`` (same stem, ``.png`` suffix)."""
        # splitext swaps only the trailing extension; str.replace would also
        # rewrite a '.jpg' that happens to occur earlier in the file name.
        stem, _ = os.path.splitext(img_name)
        return os.path.join(self.mask_dir, stem + '.png')

    def __getitem__(self, idx):
        """Load, normalise and resize the ``idx``-th image/mask pair.

        Returns:
            Tuple ``(image, mask)`` of float32 tensors, each ``(1, 256, 256)``.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = self._mask_path(img_name)

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly and report which file is the problem.
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        image = image.astype(np.float32) / 255.0
        mask = (mask > 127).astype(np.float32)  # binary mask

        # Resize for training. The mask uses nearest-neighbour so it stays
        # strictly {0, 1}; the default bilinear mode would blur the edges
        # into fractional values.
        image = cv2.resize(image, (256, 256))
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

        # Add the leading channel axis: (H, W) -> (1, H, W).
        image = np.expand_dims(image, axis=0)
        mask = np.expand_dims(mask, axis=0)

        return torch.from_numpy(image), torch.from_numpy(mask)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """Small U-Net with one Conv-BatchNorm-ReLU stage per resolution level.

    The encoder applies four stages separated by 2x max-pooling, followed by
    a 1024-channel bottleneck. The decoder upsamples bilinearly, concatenates
    the matching encoder feature map and applies one stage per level. A 1x1
    convolution followed by a sigmoid produces per-pixel values in [0, 1].

    Args:
        in_channels: Number of channels of the input image (default 1).
        out_channels: Number of output maps (default 1).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        def CBR(in_ch, out_ch):
            # 3x3 conv (padding keeps the spatial size) -> batch norm -> ReLU
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )

        # Attribute names and creation order are kept as they were so that
        # previously saved state_dict files still load.
        self.enc1 = CBR(in_channels, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.center = CBR(512, 1024)
        # Decoder inputs are the upsampled features concatenated with the skip.
        self.dec4 = CBR(1024 + 512, 512)
        self.dec3 = CBR(512 + 256, 256)
        self.dec2 = CBR(256 + 128, 128)
        self.dec1 = CBR(128 + 64, 64)
        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _up_cat(x, skip):
        """Upsample ``x`` to the spatial size of ``skip`` and concatenate on channels."""
        # Targeting the skip's size (instead of a fixed scale_factor=2) gives
        # the same result when H and W are multiples of 16 and also works when
        # pooling rounded an odd size down.
        x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Map a ``(N, in_channels, H, W)`` batch to ``(N, out_channels, H, W)`` values in [0, 1].

        H and W must be at least 16 so that four poolings leave a non-empty map.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        center = self.center(self.pool(e4))

        d4 = self.dec4(self._up_cat(center, e4))
        d3 = self.dec3(self._up_cat(d4, e3))
        d2 = self.dec2(self._up_cat(d3, e2))
        d1 = self.dec1(self._up_cat(d2, e1))

        return torch.sigmoid(self.final(d1))


In [None]:
import torch.optim as optim

# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Locations on the mounted Drive: TRAIN split of BRISC and the checkpoint target
img_dir = '/content/drive/MyDrive/BRISC/segmentation_task/train/images'
mask_dir = '/content/drive/MyDrive/BRISC/segmentation_task/train/masks'
model_save_path = '/content/drive/MyDrive/BRISC_UNet.pth'

# Number of full passes over the training set
num_epochs = 10

# Training data, reshuffled every epoch and served in batches of 8
dataset = BRISCDataset(img_dir, mask_dir)
train_loader = DataLoader(dataset, shuffle=True, batch_size=8)

# Network, pixel-wise binary cross-entropy and Adam optimiser
model = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optimisation: one pass over train_loader per epoch
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for imgs, masks in train_loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # loss is the batch mean; weight it by the batch size so the epoch
        # figure below is a per-sample average even if the last batch is short
        running_loss += imgs.size(0) * loss.item()
    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs} Loss: {epoch_loss:.4f}")

# Persist only the learned weights (state_dict)
torch.save(model.state_dict(), model_save_path)
print(f"Model saved at {model_save_path}")


Using device: cuda
Epoch 1/10 Loss: 0.2305
Epoch 2/10 Loss: 0.1022
Epoch 3/10 Loss: 0.0583
Epoch 4/10 Loss: 0.0396
Epoch 5/10 Loss: 0.0285
Epoch 6/10 Loss: 0.0222
Epoch 7/10 Loss: 0.0184
Epoch 8/10 Loss: 0.0151
Epoch 9/10 Loss: 0.0132
Epoch 10/10 Loss: 0.0117
Model saved at /content/drive/MyDrive/BRISC_UNet.pth
