In [1]:
import json
import os
import random
import shutil
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from torchvision import transforms
from torchvision.transforms import functional as F
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask
from transformers import SamModel, SamProcessor, SamConfig
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import precision_recall_curve, auc, precision_score, recall_score

# Seed Python's and PyTorch's random number generators for reproducibility.
# NOTE: NumPy's global RNG is not seeded here.
random.seed(42)
torch.manual_seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Input locations (image folder and annotation JSON) and the root directory
# under which the per-fold outputs are written
image_folder_path = "Images/"
annotation_path = "annotation/all_80.json"
output_base_path = "finetune_results/"

# Load the complete annotation file once; it is filtered per fold later on
with open(annotation_path, 'r') as f:
    coco_data = json.load(f)

# File names of every image listed in the annotation file; the k-fold split
# is computed over this list
image_files = [img['file_name'] for img in coco_data['images']]

# Helper function to save split annotations
def save_split_annotations(file_names, coco_data, output_path):
    """Write a COCO-style annotation file restricted to the given images.

    Args:
        file_names: Iterable of image file names to keep.
        coco_data: Parsed annotation dictionary. ``images``, ``annotations``
            and ``categories`` are required; ``info`` and ``licenses`` are
            copied when present and default to an empty dict / list
            otherwise.
        output_path: Path of the JSON file to write.
    """
    # Sets give O(1) membership tests; the original list lookups rescanned
    # the whole list for every image and every annotation.
    wanted_names = set(file_names)

    # Filter images
    images = [img for img in coco_data['images'] if img['file_name'] in wanted_names]
    image_ids = {img['id'] for img in images}

    # Filter annotations
    annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] in image_ids]

    # Create new COCO data dictionary
    new_coco_data = {
        'info': coco_data.get('info', {}),
        'licenses': coco_data.get('licenses', []),
        'images': images,
        'annotations': annotations,
        'categories': coco_data['categories'],
    }

    # Save new annotations
    with open(output_path, 'w') as f:
        json.dump(new_coco_data, f, indent=4)

# Helper that mirrors a selection of files into another directory
def copy_images(file_names, source_folder, destination_folder):
    """Copy every file in ``file_names`` from ``source_folder`` to ``destination_folder``.

    The destination directory is created first if it does not exist.
    """
    os.makedirs(destination_folder, exist_ok=True)
    for name in file_names:
        src = os.path.join(source_folder, name)
        dst = os.path.join(destination_folder, name)
        shutil.copy(src, dst)

# Define data augmentation and normalization transforms
class CustomTransforms:
    """Resize, optionally augment, and tensorise an image together with its mask.

    Geometric operations (resize, flips, rotation) are applied to the image
    and the mask with the same random parameters so the two stay aligned.
    Photometric operations (colour jitter, normalisation) touch the image
    only.

    Args:
        is_train: When true, random flips, a random rotation in
            [-10, 10] degrees and colour jitter are applied; otherwise only
            resize, tensor conversion and normalisation.

    Attributes:
        transforms: Image-only pipeline (colour jitter when training,
            ``ToTensor`` and ``Normalize``). It expects a PIL image that has
            already been resized and geometrically augmented.
    """

    def __init__(self, is_train=True):
        self.is_train = is_train
        self.size = (1024, 1024)

        image_steps = []
        if is_train:
            image_steps.append(
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
            )
        image_steps.append(transforms.ToTensor())
        image_steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )
        self.transforms = transforms.Compose(image_steps)

    def __call__(self, image, mask):
        """Return ``(image_tensor, mask_tensor)`` for one sample.

        Args:
            image: Image array accepted by ``to_pil_image``.
            mask: 2-D label array accepted by ``Image.fromarray``.

        Returns:
            The normalised float image tensor and a 2-D ``long`` mask tensor
            holding the original integer label values.
        """
        nearest = transforms.InterpolationMode.NEAREST

        image = F.to_pil_image(image)
        mask = Image.fromarray(mask)

        image = F.resize(image, self.size)
        # Nearest-neighbour keeps label values intact (no blended classes)
        mask = F.resize(mask, self.size, interpolation=nearest)

        if self.is_train:
            # Draw each random decision once and apply it to both inputs;
            # torch's RNG is used so torch.manual_seed controls augmentation.
            if torch.rand(1).item() < 0.5:
                image, mask = F.hflip(image), F.hflip(mask)
            if torch.rand(1).item() < 0.5:
                image, mask = F.vflip(image), F.vflip(mask)
            angle = transforms.RandomRotation.get_params([-10.0, 10.0])
            image = F.rotate(image, angle)
            mask = F.rotate(mask, angle, interpolation=nearest)

        image = self.transforms(image)
        # Convert via NumPy rather than ToTensor: ToTensor rescales uint8
        # data to [0, 1], which would truncate every label to 0 on .long().
        mask = torch.from_numpy(np.array(mask)).long()  # 2-D (H, W)
        return image, mask

# Custom dataset class
class PCBXRayDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs described by a COCO-style JSON file.

    Args:
        image_dir: Directory containing the image files.
        annotation_file: Path to the annotation JSON (``images``,
            ``annotations`` and ``categories`` keys are read).
        transforms: Optional callable ``(image, mask) -> (image, mask)``.

    Attributes:
        annotations: Mapping ``image_id -> annotation``. Kept only for
            backward compatibility: it retains just the last annotation of
            each image.
        annotations_by_image: Mapping ``image_id -> list of annotations``;
            this is what ``__getitem__`` uses.
    """

    def __init__(self, image_dir, annotation_file, transforms=None):
        self.image_dir = image_dir
        self.transforms = transforms
        with open(annotation_file, 'r') as f:
            self.coco_data = json.load(f)
        self.image_info = self.coco_data['images']
        self.annotations = {ann['image_id']: ann for ann in self.coco_data['annotations']}
        # Group every annotation by image so no object is dropped and
        # lookups in __getitem__ do not rescan the whole annotation list.
        self.annotations_by_image = {}
        for ann in self.coco_data['annotations']:
            self.annotations_by_image.setdefault(ann['image_id'], []).append(ann)
        self.category_info = self.coco_data['categories']

    def __len__(self):
        return len(self.image_info)

    @staticmethod
    def _decode_segmentation(segmentation, height, width):
        """Decode a COCO segmentation (polygons or RLE) into a 2-D binary mask."""
        if isinstance(segmentation, list):
            # Polygon format: one RLE per polygon part, merged into one region
            rle = coco_mask.merge(coco_mask.frPyObjects(segmentation, height, width))
        elif isinstance(segmentation['counts'], list):
            # Uncompressed RLE
            rle = coco_mask.frPyObjects(segmentation, height, width)
        else:
            # Already compressed RLE
            rle = segmentation
        return coco_mask.decode(rle)

    def __getitem__(self, idx):
        """Load image ``idx`` and rasterise its annotations into a label mask.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_info = self.image_info[idx]
        image_id = img_info['id']
        image_path = os.path.join(self.image_dir, img_info['file_name'])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        height, width = image.shape[:2]
        mask = np.zeros((height, width), dtype=np.uint8)
        for ann in self.annotations_by_image.get(image_id, []):
            segmentation = ann.get('segmentation')
            if not segmentation:
                continue
            region = self._decode_segmentation(segmentation, height, width)
            # Write the class label instead of summing binary masks.
            # NOTE(review): assumes category ids are the class indices the
            # model predicts (0 = background) and fit in uint8 - confirm
            # against the annotation file's categories.
            mask[region > 0] = ann['category_id']

        if self.transforms:
            image, mask = self.transforms(image, mask)

        return image, mask

class DecoderBlock(nn.Module):
    """Three consecutive 3x3 convolution -> batch norm -> LeakyReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the remaining
    two keep ``out_channels``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1)
        # Layers are created in the same order as before so parameter
        # initialisation under a fixed seed is unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the three stages in sequence and return the result."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.leaky_relu(norm(conv(x)))
        return x

class SAMWithDecoder(nn.Module):
    """SAM vision encoder followed by a convolutional head emitting per-class logits.

    Args:
        sam_model: Model exposing a ``vision_encoder`` whose output has a
            ``last_hidden_state`` attribute. The object is stored by
            reference, not copied.
        num_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, sam_model, num_classes=6):
        super().__init__()
        self.sam_model = sam_model
        self.decoder = DecoderBlock(256, 256)
        self.final_conv = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, pixel_values):
        """Encode ``pixel_values`` and return the class logits of the head."""
        encoded = self.sam_model.vision_encoder(pixel_values)
        # Assuming last_hidden_state is the (256-channel) feature map
        features = encoded.last_hidden_state
        return self.final_conv(self.decoder(features))

# Load the pre-trained SAM checkpoint 'facebook/sam-vit-huge'
sam_model = SamModel.from_pretrained('facebook/sam-vit-huge')

# Wrap the SAM model with the segmentation head and move the combined model
# to the selected device. The wrapper stores a reference to `sam_model`
# (not a copy), so both names refer to the same backbone parameters.
model_with_decoder = SAMWithDecoder(sam_model).to(device)

# Set up the optimizer with Layer-wise Learning Rate Decay
def get_optimizer_with_llrd(model, base_lr=1e-4, lr_decay_factor=0.95):
    """Build an AdamW optimizer with layer-wise learning-rate decay.

    Encoder layer ``i`` (0-based, out of ``N``) gets the learning rate
    ``base_lr * lr_decay_factor ** (N - i)``, so earlier layers are updated
    more gently. The segmentation head (``decoder`` and ``final_conv``) is
    trained at ``base_lr``.

    Args:
        model: Module exposing ``sam_model.vision_encoder.layers`` plus
            top-level ``decoder`` / ``final_conv`` submodules.
        base_lr: Learning rate of the head.
        lr_decay_factor: Multiplicative decay applied per encoder layer.

    Returns:
        An ``optim.AdamW`` instance with one parameter group per encoder
        layer followed by one group for the head.

    Note:
        As before, encoder parameters outside ``layers`` are not handed to
        the optimizer.
    """
    encoder_layers = model.sam_model.vision_encoder.layers
    num_layers = len(encoder_layers)
    seen_ids = set()

    def _unseen(params):
        # Track parameters by identity so none lands in two groups
        fresh = []
        for param in params:
            if id(param) not in seen_ids:
                seen_ids.add(id(param))
                fresh.append(param)
        return fresh

    optimizer_grouped_parameters = []
    for i, layer in enumerate(encoder_layers):
        optimizer_grouped_parameters.append({
            'params': _unseen(layer.parameters()),
            'lr': base_lr * (lr_decay_factor ** (num_layers - i)),
        })

    # Select the head by name prefix. A plain substring test for "decoder"
    # skipped `final_conv.*` (leaving the classifier untrained) and also
    # matched `sam_model.mask_decoder.*`.
    head_prefixes = ('decoder.', 'final_conv.')
    head_params = _unseen(
        p for n, p in model.named_parameters() if n.startswith(head_prefixes)
    )
    optimizer_grouped_parameters.append({'params': head_params, 'lr': base_lr})

    return optim.AdamW(optimizer_grouped_parameters)

# Module-level optimizer for `model_with_decoder`. Note that `train_model`
# below builds its own optimizer for the model it creates.
optimizer = get_optimizer_with_llrd(model_with_decoder)

# Dice Loss Function for Multiple Classes
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over ``num_classes`` classes.

    The logits are turned into probabilities with a softmax over dim 1.
    For each class the Dice coefficient is computed over the whole batch
    and ``1 - dice`` is accumulated.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs, targets, smooth=1):
        """Return the mean of ``1 - dice`` over all classes.

        Args:
            inputs: Logits with the class dimension at dim 1.
            targets: Integer label tensor compared against each class index.
            smooth: Constant added to numerator and denominator.
        """
        probabilities = torch.softmax(inputs, dim=1)

        def one_minus_dice(class_index):
            predicted = probabilities[:, class_index, :, :]
            expected = (targets == class_index).float()
            overlap = (predicted * expected).sum()
            score = (2. * overlap + smooth) / (predicted.sum() + expected.sum() + smooth)
            return 1 - score

        total = sum(one_minus_dice(c) for c in range(self.num_classes))
        return total / self.num_classes

# Shared forward pass used by both the training and the validation loop
def _upsampled_logits(model, images, target_size):
    """Run ``model`` on ``images`` and bilinearly resize the logits to ``target_size``."""
    logits = model(images)
    return torch.nn.functional.interpolate(
        logits, size=target_size, mode='bilinear', align_corners=False
    )


# Training function
def train_model(train_loader, val_loader, fold, num_epochs=10, num_classes=6, output_dir=None):
    """Train a ``SAMWithDecoder`` for one fold and checkpoint the best epoch.

    The model is built around the module-level ``sam_model``, so the
    backbone weights of that object are updated in place.

    Args:
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        fold: Zero-based fold index, used for log messages.
        num_epochs: Number of passes over ``train_loader``.
        num_classes: Number of segmentation classes.
        output_dir: Directory for checkpoints. When omitted, the
            module-level ``fold_output_path`` is used, as before.

    Returns:
        ``(train_loss, val_loss, best_model_path)``: the mean train and
        validation loss of the last epoch, and the path of the checkpoint
        with the lowest validation loss (``None`` if none was saved).
    """
    if output_dir is None:
        # Backward-compatible fallback to the global set by the CV loop
        output_dir = fold_output_path
    os.makedirs(output_dir, exist_ok=True)

    model = SAMWithDecoder(sam_model, num_classes=num_classes).to(device)
    criterion = DiceLoss(num_classes=num_classes)
    optimizer = get_optimizer_with_llrd(model)

    best_val_loss = float('inf')
    best_model_path = None
    # Defined up front so num_epochs == 0 returns cleanly
    avg_train_loss = float('nan')
    avg_val_loss = float('nan')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = _upsampled_logits(model, images, masks.shape[-2:])

            # Debug prints
            print(f"Epoch {epoch + 1}, Fold {fold + 1}, outputs shape: {outputs.shape}, masks shape: {masks.shape}")

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError for an empty loader
        avg_train_loss = train_loss / max(len(train_loader), 1)
        print(f"Fold {fold + 1}, Epoch {epoch + 1}, Train Loss: {avg_train_loss}")

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = _upsampled_logits(model, images, masks.shape[-2:])

                # Debug prints
                print(f"Epoch {epoch + 1}, Fold {fold + 1}, outputs shape: {outputs.shape}, masks shape: {masks.shape}")

                loss = criterion(outputs, masks)
                val_loss += loss.item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"Fold {fold + 1}, Epoch {epoch + 1}, Val Loss: {avg_val_loss}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(output_dir, f'best_model_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), best_model_path)

    return avg_train_loss, avg_val_loss, best_model_path

# Evaluation function
def evaluate_model(model, data_loader, num_classes=6):
    """Compute mean Dice, precision and recall of ``model`` over ``data_loader``.

    For every batch and every class index a binary prediction mask is
    compared with the binary ground-truth mask; the three metrics are
    averaged over all (batch, class) pairs.

    Args:
        model: Module mapping an image batch to class logits.
        data_loader: Iterable of ``(images, masks)`` batches.
        num_classes: Number of class indices to evaluate.

    Returns:
        ``(mean_dice, mean_precision, mean_recall)``.
    """
    model.eval()
    dice_scores = []
    precision_list = []
    recall_list = []

    with torch.no_grad():
        for images, masks in data_loader:
            images = images.to(device)
            outputs = model(images)
            outputs = torch.nn.functional.interpolate(outputs, size=masks.shape[-2:], mode='bilinear', align_corners=False)

            # Move labels to the host once per batch instead of repeating
            # the device-to-host copy for every class and every metric.
            pred_labels = torch.argmax(outputs, dim=1).cpu().numpy()
            true_labels = masks.cpu().numpy()

            for c in range(num_classes):
                # float32 matches the dtype the per-class masks had before
                pred_mask = (pred_labels == c).astype(np.float32)
                true_mask = (true_labels == c).astype(np.float32)
                dice_scores.append(dice_score(pred_mask, true_mask))

                pred_flat = pred_mask.ravel()
                true_flat = true_mask.ravel()
                precision_list.append(precision_score(true_flat, pred_flat, zero_division=0))
                recall_list.append(recall_score(true_flat, pred_flat, zero_division=0))

    mean_dice = np.mean(dice_scores)
    mean_precision = np.mean(precision_list)
    mean_recall = np.mean(recall_list)

    return mean_dice, mean_precision, mean_recall

# Dice score function
def dice_score(pred, target, smooth=1):
    """Return the smoothed Dice coefficient of two masks.

    Computed as ``(2 * sum(pred * target) + smooth) /
    (sum(pred) + sum(target) + smooth)``.
    """
    overlap = np.sum(pred * target)
    total = np.sum(pred) + np.sum(target)
    return (2. * overlap + smooth) / (total + smooth)

# Perform 5-fold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

# Snapshot of the pretrained SAM weights (held on the CPU). Training and
# checkpoint loading both write into the shared `sam_model` object, so
# without a reset every fold would start from the previous fold's weights
# and would already have seen its own validation images.
pretrained_sam_state = {
    name: tensor.detach().cpu().clone()
    for name, tensor in sam_model.state_dict().items()
}

for fold, (train_idx, val_idx) in enumerate(kf.split(image_files)):
    print(f'Fold {fold + 1}')

    # Start every fold from the original pretrained backbone
    sam_model.load_state_dict(pretrained_sam_state)

    train_files = [image_files[i] for i in train_idx]
    val_files = [image_files[i] for i in val_idx]

    # Per-fold directory layout: <fold>/train, <fold>/val and two JSON files
    fold_output_path = os.path.join(output_base_path, f'fold_{fold + 1}')
    train_image_path = os.path.join(fold_output_path, 'train')
    val_image_path = os.path.join(fold_output_path, 'val')
    train_annotation_path = os.path.join(fold_output_path, 'train_annotations.json')
    val_annotation_path = os.path.join(fold_output_path, 'val_annotations.json')

    os.makedirs(train_image_path, exist_ok=True)
    os.makedirs(val_image_path, exist_ok=True)

    save_split_annotations(train_files, coco_data, train_annotation_path)
    save_split_annotations(val_files, coco_data, val_annotation_path)
    copy_images(train_files, image_folder_path, train_image_path)
    copy_images(val_files, image_folder_path, val_image_path)

    train_dataset = PCBXRayDataset(train_image_path, train_annotation_path, transforms=CustomTransforms(is_train=True))
    val_dataset = PCBXRayDataset(val_image_path, val_annotation_path, transforms=CustomTransforms(is_train=False))

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    train_loss, val_loss, best_model_path = train_model(train_loader, val_loader, fold, num_classes=6)

    # Evaluate the best checkpoint of this fold on its validation set.
    # map_location lets a checkpoint be loaded on a different device than
    # the one it was saved from.
    model_with_decoder.load_state_dict(torch.load(best_model_path, map_location=device))
    mean_dice, mean_precision, mean_recall = evaluate_model(model_with_decoder, val_loader, num_classes=6)

    fold_results.append({
        'fold': fold + 1,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'best_model_path': best_model_path,
        'mean_dice': mean_dice,
        'mean_precision': mean_precision,
        'mean_recall': mean_recall,
    })

# Compute average results
avg_train_loss = np.mean([res['train_loss'] for res in fold_results])
avg_val_loss = np.mean([res['val_loss'] for res in fold_results])
avg_dice = np.mean([res['mean_dice'] for res in fold_results])
avg_precision = np.mean([res['mean_precision'] for res in fold_results])
avg_recall = np.mean([res['mean_recall'] for res in fold_results])

# Find the best fold: highest mean Dice, ties broken by mean precision
best_model = max(fold_results, key=lambda x: (x['mean_dice'], x['mean_precision']))

print("Cross-validation results:", fold_results)
print(f"Average Train Loss: {avg_train_loss}")
print(f"Average Val Loss: {avg_val_loss}")
print(f"Average Dice Score: {avg_dice}")
print(f"Average Precision: {avg_precision}")
print(f"Average Recall: {avg_recall}")
print(f"Best model path: {best_model['best_model_path']}")


Using device: cpu
Fold 1
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape: torch.Size([1, 6, 1024, 1024]), masks shape: torch.Size([1, 1024, 1024])
Epoch 1, Fold 1, outputs shape