In [14]:
import os
from network import modeling
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import models, transforms
from PIL import Image


# Define the dataset class
class MHISTDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None, partition="train"):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.partition = partition

        # Filter by train/test split
        self.data = self.data[self.data["Partition"] == partition]

        # Mapping labels to numeric values (SSA = 1, HP = 0)
        self.label_map = {"SSA": 1, "HP": 0}
        self.data["Majority Vote Label"] = self.data["Majority Vote Label"].map(self.label_map)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]["Image Name"]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")

        label = self.data.iloc[idx]["Majority Vote Label"]

        if self.transform:
            image = self.transform(image)

        label_mask = torch.full((224, 224), label, dtype=torch.long)

        return image, label_mask


# Function to get dataset loaders
def get_mhist_dataloader(csv_file, img_dir, batch_size=16, partition="train"):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = MHISTDataset(csv_file, img_dir, transform=transform, partition=partition)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader


In [15]:
# Configurations
CSV_PATH = "mhist_annotations.csv"
IMG_DIR = "images/"
BATCH_SIZE = 16
NUM_EPOCHS = 25
LEARNING_RATE = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "models/deeplabv3plus_resnet50.pth"
MODEL_NAME = "deeplabv3plus_resnet50"
NUM_CLASSES = 2  # SSA = 1, HP = 0 (Binary Segmentation)
OUTPUT_STRIDE = 16  # Default for DeepLabV3+

# Load Dataset
dataloaders = {
    "train": get_mhist_dataloader(CSV_PATH, IMG_DIR, batch_size=BATCH_SIZE, partition="train"),
    "val": get_mhist_dataloader(CSV_PATH, IMG_DIR, batch_size=BATCH_SIZE, partition="test"),
}

# Load Pretrained Model
model = modeling.__dict__[MODEL_NAME](num_classes=NUM_CLASSES, output_stride=OUTPUT_STRIDE)

checkpoint = torch.load(MODEL_PATH, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), weights_only=False)

state_dict = checkpoint['model_state']
del state_dict['classifier.classifier.3.weight']
del state_dict['classifier.classifier.3.bias']

model.load_state_dict(state_dict, strict=False)

model.classifier = modeling.DeepLabHeadV3Plus(
    in_channels=2048,        # Number of input channels from ResNet50 (2048)
    low_level_channels=256,  # Low-level feature channels from ResNet50 (typically 256)
    num_classes=NUM_CLASSES, # Set to 2 for binary classification (SSA or HP)
    aspp_dilate=[12, 24, 36] # Default atrous rates for ASPP
)

model.load_state_dict(checkpoint['model_state'], strict=False)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [16]:
# Define Loss and Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [17]:
# Training Loop
def train_model():
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            correct = 0
            total = 0
            
            for images, labels in dataloaders[phase]:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(images)["out"]
                    loss = criterion(outputs, labels)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                correct += torch.sum(preds == labels.data).item()
                total += labels.size(0)
            
            epoch_loss = running_loss / total
            epoch_acc = correct / total
            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        
    torch.save(model.state_dict(), "models/mhist_trained.pth")
    print("Training complete! Model saved.")

In [18]:
train_model()

Epoch 1/25


TypeError: new(): invalid data type 'str'