### Installing dependencies 

In [None]:
!pip install datasets

In [None]:
!pip install --upgrade albumentations


### Training the ViT (MAE) encoder with a UNet-style decoder

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import numpy as np
import albumentations as A
from transformers.modeling_outputs import SemanticSegmenterOutput
from PIL import Image
import matplotlib.pyplot as plt
from datasets import load_dataset
import timm

# Global cuDNN switches, set once at import time for the whole notebook.
# NOTE(review): benchmark mode presumably pays off here because every input is
# resized to the same 224x224 shape - confirm if variable input sizes are added.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

class SegmentationDataset(Dataset):
    """Wrap an indexable collection of ``{"image", "label"}`` records for training.

    Each record is converted to NumPy arrays, passed through ``transform``
    (called as ``transform(image=..., mask=...)`` and expected to return a
    mapping with ``'image'`` and ``'mask'`` entries), and returned as tensors.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        augmented = self.transform(
            image=np.array(record["image"]),
            mask=np.array(record["label"]),
        )

        # Channels-last (H, W, C) -> channels-first (C, H, W).
        image = torch.tensor(augmented['image']).permute(2, 0, 1)
        label = torch.LongTensor(augmented['mask'])
        return image, label

# Per-channel normalization statistics, written on the 0-255 scale and
# rescaled to the 0-1 range as plain Python floats.
ADE_MEAN = [channel / 255 for channel in (123.675, 116.280, 103.530)]
ADE_STD = [channel / 255 for channel in (58.395, 57.120, 57.375)]

# Training-time augmentation pipeline. It is invoked with both `image=` and
# `mask=` by SegmentationDataset, and always ends with normalization.
train_transform = A.Compose([
    # Fixed input size for every sample.
    A.Resize(width=224, height=224),
    # Geometric augmentations.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5),
    # Distortion group, entered with probability 0.3.
    A.OneOf([
        A.GridDistortion(p=0.5),
        A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5),
    ], p=0.3),
    # Colour / contrast group, entered with probability 0.3.
    A.OneOf([
        A.HueSaturationValue(10,15,10),
        A.CLAHE(clip_limit=2),
        A.RandomBrightnessContrast(),
    ], p=0.3),
    # Normalize last so the colour transforms above see un-normalized pixels.
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

# Deterministic pipeline used for validation and inference: same resize and
# normalization as training, without any random augmentation.
val_transform = A.Compose([
    A.Resize(width=224, height=224),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Layers are created in the same order as they run, so the sequential
        # indices (and therefore the state-dict keys) are 0..5.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class UNetDecoder(nn.Module):
    """Four upsample + DoubleConv stages followed by a 1x1 classification conv.

    Each stage doubles the spatial resolution with a stride-2 transposed
    convolution, so the output is 16x larger than the input feature map in
    both spatial dimensions. No skip connections are used.
    """

    # (transposed-conv output channels, DoubleConv output channels) per stage.
    _STAGE_WIDTHS = ((512, 256), (256, 128), (128, 64), (64, 32))

    def __init__(self, in_channels, num_classes):
        super().__init__()
        stage_in = in_channels
        for stage, (up_width, conv_width) in enumerate(self._STAGE_WIDTHS, start=1):
            # Registered as up1/conv1 ... up4/conv4, in that order.
            setattr(self, f"up{stage}", nn.ConvTranspose2d(stage_in, up_width, kernel_size=2, stride=2))
            setattr(self, f"conv{stage}", DoubleConv(up_width, conv_width))
            stage_in = conv_width
        self.final_conv = nn.Conv2d(stage_in, num_classes, kernel_size=1)

    def forward(self, x):
        for stage in range(1, len(self._STAGE_WIDTHS) + 1):
            upsample = getattr(self, f"up{stage}")
            refine = getattr(self, f"conv{stage}")
            x = refine(upsample(x))
        return self.final_conv(x)

class MAEForSemanticSegmentation(nn.Module):
    """ViT encoder followed by a UNetDecoder head for per-pixel classification.

    Args:
        vit_model: encoder exposing ``num_features``, ``forward_features`` and
            ``patch_embed.num_patches``.
        num_labels: number of output classes.
    """

    def __init__(self, vit_model, num_labels):
        super().__init__()
        self.vit = vit_model
        self.decoder = UNetDecoder(in_channels=self.vit.num_features, num_classes=num_labels)

    def forward(self, pixel_values, labels=None):
        """Return a SemanticSegmenterOutput with logits at the input resolution.

        When ``labels`` is given, ``loss`` is the cross-entropy with class 0
        ignored; otherwise ``loss`` is ``None``.
        """
        tokens = self.vit.forward_features(pixel_values)

        # Exactly one extra token in front of the patch tokens: drop it.
        expected_with_prefix = self.vit.patch_embed.num_patches + 1
        if tokens.shape[1] == expected_with_prefix:
            tokens = tokens[:, 1:, :]

        n_batch, n_tokens, dim = tokens.shape
        # The patch tokens are laid out as a square grid.
        side = int(n_tokens ** 0.5)
        grid = tokens.reshape(n_batch, side, side, dim).permute(0, 3, 1, 2)

        decoded = self.decoder(grid)
        logits = F.interpolate(decoded, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

        if labels is None:
            loss = None
        else:
            loss = nn.CrossEntropyLoss(ignore_index=0)(logits, labels)

        return SemanticSegmenterOutput(loss=loss, logits=logits)

class FocalLoss(nn.Module):
    """Focal loss on top of per-pixel cross-entropy.

    Args:
        alpha: constant scale applied to every loss term.
        gamma: focusing exponent applied to ``(1 - pt)``.
        reduction: ``'mean'``, ``'sum'``, or anything else for the unreduced
            per-element loss.
        ignore_index: target value whose positions contribute no loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean', ignore_index=255):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Compute the focal loss of logits ``inputs`` against class ids ``targets``."""
        # Ignored positions come back as exactly 0 from cross_entropy, hence
        # pt == 1 and their focal term is 0 as well.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            # Average over valid positions only; a plain .mean() would divide
            # by the ignored positions too and shrink the loss in proportion
            # to how much of the batch is ignored.
            valid_count = (targets != self.ignore_index).sum().clamp(min=1)
            return focal_loss.sum() / valid_count
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class DiceLoss(nn.Module):
    """Soft multi-class Dice loss, averaged over class channels.

    Args:
        smooth: constant added to numerator and denominator of each class's
            Dice ratio to avoid division by zero.
        ignore_index: target value whose positions are excluded from the sums.

    NOTE: every class channel takes part in the final mean, including a class
    that equals ``ignore_index`` and therefore never appears as a valid target.
    """

    def __init__(self, smooth=1e-6, ignore_index=255):
        super().__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, y_pred, y_true):
        """Return ``1 - mean Dice`` for logits ``y_pred`` and class ids ``y_true``."""
        num_classes = y_pred.shape[1]
        y_pred = F.softmax(y_pred, dim=1)

        valid = y_true != self.ignore_index
        # one_hot rejects values >= num_classes (e.g. the default 255), so
        # ignored positions are remapped to class 0 first. They are zeroed by
        # `mask` below, so the placeholder class never reaches the sums.
        safe_targets = torch.where(valid, y_true, torch.zeros_like(y_true))
        y_true_onehot = F.one_hot(safe_targets, num_classes=num_classes).permute(0, 3, 1, 2).float()

        mask = valid.float().unsqueeze(1)

        # Sum over batch and spatial dims, leaving one value per class.
        intersection = torch.sum(y_pred * y_true_onehot * mask, dim=[0, 2, 3])
        union = torch.sum((y_pred + y_true_onehot) * mask, dim=[0, 2, 3])

        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

def _checkpoint_dir(model_name, is_cam):
    """Create (if needed) and return the directory that checkpoints are written to."""
    cam_type = "CAM" if is_cam else "NO_CAM"
    save_dir = f"{cam_type}/{model_name}"
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def train_model(model, train_dataloader, val_dataloader, num_epochs=100, learning_rate=5e-5, model_name="", is_cam=False):
    """Train the decoder of ``model`` with a frozen ``model.vit`` encoder.

    The training loss is Dice + focal (both ignoring class 0). After every
    epoch the validation Dice score is computed; the weights of the best epoch
    are snapshotted and written to disk once training finishes. A regular
    checkpoint is also written every 10 epochs and after the last epoch.

    Args:
        model: module with a ``vit`` attribute whose forward accepts
            ``pixel_values`` (and optionally ``labels``) and returns an object
            with a ``logits`` attribute.
        train_dataloader: yields ``(pixel_values, labels)`` training batches.
        val_dataloader: yields ``(pixel_values, labels)`` validation batches.
        num_epochs: number of epochs; also the cosine schedule period.
        learning_rate: initial AdamW learning rate.
        model_name: sub-directory name used for checkpoints.
        is_cam: selects the ``CAM`` or ``NO_CAM`` checkpoint root directory.

    Returns:
        ``(model, best_dice)`` - the model after the final epoch and the best
        validation Dice score seen.
    """
    device = get_device()
    print(f"Using device: {device}")
    model.to(device)

    # Freeze the encoder; only the decoder receives gradients.
    for param in model.vit.parameters():
        param.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    dice_loss = DiceLoss(ignore_index=0).to(device)
    focal_loss = FocalLoss(ignore_index=0).to(device)

    best_dice = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            pixel_values, labels = batch
            pixel_values, labels = pixel_values.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(pixel_values=pixel_values, labels=labels)
            # The model's own loss is not used; the objective is Dice + focal.
            loss = dice_loss(outputs.logits, labels) + focal_loss(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
        avg_loss = total_loss / max(len(train_dataloader), 1)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0
        dice_scores = []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Validation"):
                pixel_values, labels = batch
                pixel_values, labels = pixel_values.to(device), labels.to(device)

                outputs = model(pixel_values=pixel_values)
                # One forward through the Dice loss serves both the validation
                # loss and the Dice score (score = 1 - loss).
                batch_dice_loss = dice_loss(outputs.logits, labels).item()
                val_loss += batch_dice_loss
                dice_scores.append(1 - batch_dice_loss)

        val_loss /= max(len(val_dataloader), 1)
        avg_dice = float(np.mean(dice_scores)) if dice_scores else 0.0
        print(f"Validation Loss: {val_loss:.4f}, Average Dice Score: {avg_dice:.4f}")

        if avg_dice > best_dice:
            best_dice = avg_dice
            # state_dict() returns references to the live parameters, which
            # later epochs keep updating. Clone them (on the CPU) so the
            # snapshot really holds this epoch's weights.
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"New best model found with Dice Score: {best_dice:.4f}")

        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            save_dir = _checkpoint_dir(model_name, is_cam)
            save_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}_dice_{avg_dice:.4f}.pth")
            torch.save(model.state_dict(), save_path)
            print(f"Checkpoint saved at epoch {epoch+1}")

    print("Training completed!")

    if best_model_state is not None:
        save_dir = _checkpoint_dir(model_name, is_cam)
        save_path = os.path.join(save_dir, f"best_model_dice_{best_dice:.4f}.pth")
        torch.save(best_model_state, save_path)
        print(f"Best model saved with Dice Score: {best_dice:.4f}")

    return model, best_dice

def main(model_list):
    """Train one segmentation model per entry of ``model_list``.

    Args:
        model_list: iterable of dicts with keys ``'weights_path'`` (path to a
            ViT state dict loaded into the encoder) and ``'is_cam'`` (bool
            selecting the ``CAM`` / ``NO_CAM`` checkpoint directory).
    """
    device = get_device()
    print(f"Using device: {device}")

    dataset = load_dataset("Ayushnangia/FUGseg_dilation")
    train_dataset = SegmentationDataset(dataset["train"], transform=train_transform)
    val_dataset = SegmentationDataset(dataset["validation"], transform=val_transform)

    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)

    id2label = {
        0: "background",
        1: "ulcer",
        2: "rice",
        3: "sausages",
        4: "rice_sausages",
        5: "rice_ulcer",
        6: "sausages_ulcer",
    }

    for model_info in model_list:
        weights_path = model_info['weights_path']
        is_cam = model_info['is_cam']

        # The checkpoint file name doubles as the run name / output directory.
        model_name = os.path.basename(weights_path)

        print(f"Training model: {model_name} ({'CAM' if is_cam else 'NO_CAM'})")

        vit_model = timm.create_model('vit_base_patch16_224', pretrained=False)

        state_dict = torch.load(weights_path, map_location=device)
        # strict=False tolerates extra or missing keys, but it also hides a
        # checkpoint that matches nothing - and the encoder is frozen during
        # training, so random encoder weights would never be corrected.
        # Surface what was (not) loaded.
        load_result = vit_model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print(f"Warning: {len(load_result.missing_keys)} encoder keys missing from {model_name}: "
                  f"{load_result.missing_keys[:5]}")
        if load_result.unexpected_keys:
            print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in {model_name}: "
                  f"{load_result.unexpected_keys[:5]}")
        if len(load_result.unexpected_keys) == len(state_dict):
            print(f"Warning: no weights from {weights_path} matched the encoder; "
                  "it keeps its random initialisation.")

        model = MAEForSemanticSegmentation(vit_model, num_labels=len(id2label))

        _, best_dice = train_model(model, train_dataloader, val_dataloader, num_epochs=100, learning_rate=5e-5, model_name=model_name, is_cam=is_cam)

        print(f"Best Dice Score for {model_name}: {best_dice:.4f}")
        print("-----------------------------")






In [None]:
# Checkpoints to fine-tune, one training run per entry. The 'weights_path'
# values below are placeholders and must be replaced with real file paths;
# 'is_cam' selects the CAM / NO_CAM output directory.
model_list = [
    {
        'weights_path': '// path for weight path CAM MAE',
        'is_cam': True
    },
    {
        'weights_path': '// path for weight path NON CAM MAE',
        'is_cam': False
    }
]
main(model_list)

### Inference code

In [76]:
# Rebuild the architecture used for training, then load the fine-tuned weights.
vit_model = timm.create_model('vit_base_patch16_224', pretrained=False)

# map_location="cpu" lets a checkpoint saved from a CUDA run load on a machine
# without a GPU; inference() moves the model to the active device afterwards.
# The path below is a placeholder and must be replaced with a real file path.
state_dict = torch.load("// path to MAE finetuned model", map_location="cpu")

id2label = {
    0: "background",
    1: "ulcer",
    2: "rice",
    3: "sausages",
    4: "rice_sausages",
    5: "rice_ulcer",
    6: "sausages_ulcer",
}
model = MAEForSemanticSegmentation(vit_model, num_labels=len(id2label))

model.load_state_dict(state_dict)

<All keys matched successfully>

In [77]:
def inference(model, image_path):
    """Predict a per-pixel class map for the image at ``image_path``.

    Args:
        model: segmentation model whose forward accepts ``pixel_values`` and
            returns an object with a ``logits`` attribute.
        image_path: path of the image file to segment.

    Returns:
        ``(image, predicted_map)`` - the PIL image as opened from disk and a
        NumPy array of class indices with the image's height and width.
    """
    device = get_device()
    model.to(device)
    model.eval()

    image = Image.open(image_path)
    # The normalization statistics and the channel permute below assume three
    # channels, so grayscale / RGBA / palette files are converted for the
    # model input. The image as opened is still what gets returned.
    image_np = np.array(image.convert("RGB"))

    transformed = val_transform(image=image_np)
    image_tensor = torch.tensor(transformed['image']).permute(2, 0, 1).unsqueeze(0).float().to(device)

    with torch.no_grad():
        outputs = model(pixel_values=image_tensor)

    # PIL reports (width, height); interpolate expects (height, width).
    upsampled_logits = F.interpolate(outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False)
    # Drop only the batch dimension so a 1-pixel-wide or -tall image stays 2-D.
    predicted_map = upsampled_logits.argmax(dim=1).squeeze(0).cpu().numpy()

    return image, predicted_map



In [78]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

def visualize_ulcer_overlay(image, predicted_map, id2label):
    """Show ``image`` with the pixels predicted as ulcer highlighted in green.

    Args:
        image: PIL image or NumPy array (grayscale, RGB or RGBA).
        predicted_map: 2-D array of class indices with the image's height and
            width.
        id2label: mapping of class index to label name; the index labelled
            ``"ulcer"`` is highlighted (index 1 when no such label exists).
    """
    from matplotlib.patches import Patch

    plt.figure(figsize=(10, 5))

    if isinstance(image, Image.Image):
        image = np.array(image)

    # Bring every input to 3-channel RGB so it can be blended with the overlay.
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Look the class index up instead of hard-coding it, keeping 1 as fallback.
    # NOTE(review): the combined labels ("rice_ulcer", "sausages_ulcer") are
    # not highlighted - confirm whether they should be.
    ulcer_class = next(
        (index for index, label in (id2label or {}).items() if label == "ulcer"),
        1,
    )
    ulcer_mask = (predicted_map == ulcer_class).astype(np.uint8)

    overlay = np.zeros((*image.shape[:2], 3), dtype=np.uint8)
    overlay[ulcer_mask == 1] = [0, 255, 0]

    # The image keeps full weight; the green overlay is added at half strength.
    alpha = 0.5
    blended = cv2.addWeighted(image, 1, overlay, alpha, 0)

    plt.imshow(blended)
    plt.title("Original Image with Ulcer Overlay")
    plt.axis('off')

    legend_elements = [Patch(facecolor='green', edgecolor='green', label='Ulcer')]
    plt.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    plt.show()

In [None]:
# Run the model on a single image and display the ulcer overlay.
# The path below is a placeholder and must be replaced with a real image path.
image_path = "// path to ulcer image"
image, predicted_map = inference(model, image_path)

visualize_ulcer_overlay(image, predicted_map, id2label)

