In [None]:
import os
import time
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

In [None]:
!pip install segmentation-models-pytorch albumentations --quiet


In [None]:
import segmentation_models_pytorch as smp
import albumentations as A
import cv2
import numpy as np
import pandas as pd
from glob import glob
from torchvision import transforms as T
from tqdm.notebook import tqdm

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so weight initialisation and DataLoader shuffling are
# repeatable across Restart & Run All (some GPU kernels may still be
# nondeterministic).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class MriDataset(Dataset):
    """Paired MRI slice / segmentation-mask dataset backed by a DataFrame.

    Args:
        df: DataFrame with ``image_filename`` and ``mask_filename`` columns
            holding paths readable by ``cv2.imread``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys.

    Each item is ``(image_tensor, mask_tensor)``; the mask is a float32 tensor
    of 0/1 values with a leading channel dimension of size 1.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read(path, flags):
        """Read an image with OpenCV, raising instead of returning None."""
        data = cv2.imread(path, flags)
        if data is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # which would otherwise surface as an obscure error in the transform.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return data

    @staticmethod
    def _mask_to_tensor(mask):
        """Binarise an 8-bit mask and return it as a (1, H, W) float32 tensor."""
        # Threshold at half range rather than ``mask // 255``: integer division
        # keeps only pixels equal to exactly 255 and silently drops any other
        # bright foreground value.
        binary = (np.asarray(mask) > 127).astype(np.float32)
        return torch.from_numpy(binary).unsqueeze(0)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # NOTE(review): cv2 loads colour images in BGR order; if the transform
        # normalises with ImageNet RGB statistics the channel order differs.
        # Training and inference must agree, so change both or neither.
        img = self._read(row['image_filename'], cv2.IMREAD_COLOR)
        mask = self._read(row['mask_filename'], cv2.IMREAD_GRAYSCALE)

        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        img = T.functional.to_tensor(img)
        mask = self._mask_to_tensor(mask)
        return img, mask

In [None]:
# Shared preprocessing: resize every slice to 256x256, then normalise with the
# ImageNet mean/std (albumentations also resizes the mask when one is passed).
# NOTE(review): these statistics are in RGB order, while images in this
# notebook are loaded with cv2 (BGR) — confirm this is intended.
transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [None]:
# CSV indexes of image/mask file pairs.
# NOTE(review): TRAIN_CSV is an absolute Google Drive path while VALID_CSV is
# relative to the working directory — confirm both point at the intended files.
TRAIN_CSV = "/content/drive/MyDrive/mri_segmentation/lgg-mri-segmentation/kaggle_3m/data.csv"
VALID_CSV = "valid_data.csv"
BATCH_SIZE = 16
NUM_WORKERS = 2
# Columns MriDataset reads for every row.
REQUIRED_COLUMNS = {"image_filename", "mask_filename"}

train_df = pd.read_csv(TRAIN_CSV)
valid_df = pd.read_csv(VALID_CSV)

# Fail fast with a readable message instead of a KeyError raised later inside
# a DataLoader worker when a row is indexed.
for split_name, frame in (("train", train_df), ("valid", valid_df)):
    missing = REQUIRED_COLUMNS - set(frame.columns)
    if missing:
        raise ValueError(f"{split_name} CSV is missing columns: {sorted(missing)}")

train_dataset = MriDataset(train_df, transform=transform)
valid_dataset = MriDataset(valid_df, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# U-Net with an ImageNet-pretrained EfficientNet-B7 encoder, 3 input channels
# and a single output channel. activation='sigmoid' makes the forward pass
# return probabilities in [0, 1] rather than raw logits.
model = smp.Unet(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation='sigmoid',
)
model = model.to(device)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth
100%|██████████| 254M/254M [00:13<00:00, 19.3MB/s]


In [None]:
# The model is built with activation='sigmoid', so its outputs are already
# probabilities. smp's DiceLoss defaults to from_logits=True and would apply a
# second sigmoid, squashing predictions into (0.5, 0.73) and distorting the
# gradient; declare the inputs as probabilities instead.
loss_fn = smp.losses.DiceLoss(mode='binary', from_logits=False)

LEARNING_RATE = 1e-3
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Multiply the LR by 0.2 once the monitored loss has failed to improve for
# more than 2 consecutive scheduler steps.
lr_scheduler = ReduceLROnPlateau(optimizer, patience=2, factor=0.2)

In [None]:
def _train_one_epoch(model, loader, criterion, opt, dev, desc):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    total = 0.0
    for img, mask in tqdm(loader, desc=desc):
        img, mask = img.to(dev), mask.to(dev)
        opt.zero_grad()
        loss = criterion(model(img), mask)
        loss.backward()
        opt.step()
        total += loss.item()
    return total / len(loader)


def _evaluate(model, loader, criterion, dev):
    """Return the mean batch loss of ``model`` over ``loader`` without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for img, mask in loader:
            img, mask = img.to(dev), mask.to(dev)
            total += criterion(model(img), mask).item()
    return total / len(loader)


def train_model(model, train_loader, valid_loader, epochs=10, *,
                criterion=None, opt=None, scheduler=None, target_device=None,
                save_path="brain_tumor_segmentation.pth"):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: network to train; switched between train/eval mode as needed.
        train_loader: sized iterable of ``(image, mask)`` batches for updates.
        valid_loader: sized iterable of ``(image, mask)`` batches for validation.
        epochs: number of passes over ``train_loader``.
        criterion: loss callable ``criterion(pred, mask)``; defaults to the
            notebook-level ``loss_fn``.
        opt: optimizer over ``model``'s parameters; defaults to the
            notebook-level ``optimizer``.
        scheduler: scheduler whose ``step`` takes the mean validation loss;
            defaults to the notebook-level ``lr_scheduler``.
        target_device: device batches are moved to; defaults to ``device``.
        save_path: file the final ``state_dict`` is written to.

    Returns:
        List of ``(mean_train_loss, mean_val_loss)`` tuples, one per epoch.

    Raises:
        ValueError: if either loader has no batches.
    """
    # Check up front instead of dividing by zero after a full training epoch.
    if len(train_loader) == 0 or len(valid_loader) == 0:
        raise ValueError("train_loader and valid_loader must each yield at least one batch")

    # Fall back to the objects configured earlier in the notebook.
    criterion = loss_fn if criterion is None else criterion
    opt = optimizer if opt is None else opt
    scheduler = lr_scheduler if scheduler is None else scheduler
    dev = device if target_device is None else target_device

    history = []
    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, opt, dev,
                                      desc=f"Training Epoch {epoch+1}")
        val_loss = _evaluate(model, valid_loader, criterion, dev)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
        # Step on the mean (not summed) validation loss so the monitored value
        # matches the one printed above.
        scheduler.step(val_loss)
        history.append((train_loss, val_loss))

    # Save trained model
    torch.save(model.state_dict(), save_path)
    print("Model saved.")
    return history

# Train model
history = train_model(model, train_loader, valid_loader)

In [None]:
def predict(image_path, model_path="brain_tumor_segmentation.pth", net=None):
    """Run the segmentation network on a single image file.

    Args:
        image_path: path to the input image, read with ``cv2.imread`` in colour.
        model_path: checkpoint whose ``state_dict`` is loaded into the network
            before inference; pass ``None`` to use the network's current weights.
        net: network to run; defaults to the notebook-level ``model``.

    Returns:
        ``numpy`` array of the network output for the image resized to
        256x256, with size-1 dimensions squeezed away. With the notebook's
        ``model`` (built with ``activation='sigmoid'``) these are per-pixel
        probabilities.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read.
    """
    net = model if net is None else net
    if model_path is not None:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None rather than raising for missing/unreadable files.
        raise FileNotFoundError(f"Could not read image file: {image_path}")

    # Mirrors the resize/normalisation applied to the training data.
    preprocess = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    image = preprocess(image=image)['image']
    batch = T.functional.to_tensor(image).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = net(batch).squeeze().cpu().numpy()
    return pred
