# SINet COD10K Detection
Questo notebook implementa SINet per il rilevamento di oggetti mimetizzati.

In [1]:
import os
import glob
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.models as models
import torch.fft as fft


In [2]:
# Pick the compute device: the Apple-silicon GPU (MPS) the notebook was written
# for when available, otherwise CUDA, otherwise CPU, so the notebook also runs
# on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Parametri ottimali
* Adam decay: 1e-4
* Resize: 416 x 416
* Batch: 40
* Epochs: 180 (provo con 90)

Nota: le celle seguenti usano per ora Resize 224 x 224 e 10 epoche.

### Dataset Class

In [3]:
class CODDataset(Dataset):
    """Paired (image, mask) dataset for camouflaged-object segmentation.

    Each image file in ``image_folder`` (recognised by extension) is paired
    with the mask in ``mask_folder`` that has the same base name and a
    ``.png`` extension.

    Args:
        image_folder: directory containing the input images.
        mask_folder: directory containing the ground-truth masks.
        image_transform: optional callable applied to the RGB PIL image.
        mask_transform: optional callable applied to the grayscale PIL mask.
    """

    # Only files with these extensions are treated as images; this skips
    # stray directory entries such as macOS ``.DS_Store`` files.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, image_folder, mask_folder,
                 image_transform=None, mask_transform=None):
        # os.listdir returns entries in arbitrary order; sorting makes sample
        # indices deterministic across runs and platforms.
        self.image_files = sorted(
            name for name in os.listdir(image_folder)
            if os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
        )
        self.image_folder = image_folder
        self.mask_folder = mask_folder
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    @staticmethod
    def _mask_filename(image_filename):
        """Return the mask file name for ``image_filename`` (same stem, ``.png``)."""
        # Swap only the extension; str.replace would also rewrite a '.jpg'
        # that happens to occur in the middle of the name.
        return os.path.splitext(image_filename)[0] + '.png'

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)`` after the optional transforms."""
        image_name = self.image_files[idx]
        img_path = os.path.join(self.image_folder, image_name)
        mask_path = os.path.join(self.mask_folder, self._mask_filename(image_name))

        # convert() returns a new, fully loaded image, so the source files can
        # be closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.image_files)

In [4]:
# Input-image pipeline: resize to the network resolution, light Gaussian blur,
# then normalisation with the ImageNet statistics expected by the pretrained
# ResNet-50.
# NOTE: image and mask use two separate Compose pipelines, so any random
# transform is drawn independently for each. Re-enabling the geometric
# augmentations below (flips/rotation) as-is would misalign image and ground
# truth; they need a paired (joint) transform.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
    # transforms.RandomRotation(15),
    # transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.GaussianBlur(kernel_size=(3, 3)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Mask pipeline: nearest-neighbour resize and no blur, so no new intermediate
# label values are introduced (presumably the masks are 0/255, i.e. {0, 1}
# after ToTensor - confirm for the dataset). Bilinear resizing or blurring
# would create fractional labels that distort the TP/FP/FN counts used by the
# evaluation metrics.
mask_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
    # transforms.RandomRotation(15),
    transforms.ToTensor()
])


### Backbone feature extraction

In [5]:
class ResNetBackbone(nn.Module):
    """ResNet-50 feature extractor returning the outputs of its five stages.

    The input is first replaced by its centred FFT amplitude spectrum (see
    ``fourier_transform``) and then passed through the ResNet stages. The
    five outputs have strides 2, 4, 8, 16, 32 w.r.t. the input and 64, 256,
    512, 1024, 2048 channels respectively.

    Args:
        pretrained: load ImageNet weights for the ResNet-50.
    """

    def __init__(self, pretrained=True):
        super(ResNetBackbone, self).__init__()
        # torchvision >= 0.13 replaced the deprecated ``pretrained`` flag with
        # ``weights``; fall back to the old flag on older releases.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        self.stage1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # stride 2,  64 ch
        self.pool = resnet.maxpool                                          # stride 4
        self.stage2 = resnet.layer1  # stride 4,  256 ch
        self.stage3 = resnet.layer2  # stride 8,  512 ch
        self.stage4 = resnet.layer3  # stride 16, 1024 ch
        self.stage5 = resnet.layer4  # stride 32, 2048 ch

    def fourier_transform(self, x):
        """Return the centred amplitude spectrum of ``x`` over its last two dims.

        Computes ``|fftshift(fft2(x, norm="ortho"))|`` per sample and channel.
        The result has the same shape as ``x``, dtype float32, and lives on
        ``x.device``. The operation is differentiable.
        """
        # The FFT runs on the CPU, as the original NumPy version did, because
        # some backends (notably older MPS builds) lack FFT kernels. torch ops
        # keep the autograd graph intact, and .cpu()/.to() are differentiable.
        freq = fft.fft2(x.cpu(), norm="ortho")
        # Shift only the spatial dims: fftshift over all axes would also roll
        # the batch and channel axes and reorder the samples within the batch.
        freq = fft.fftshift(freq, dim=(-2, -1))
        return freq.abs().to(device=x.device, dtype=torch.float32)

    def forward(self, x):
        """Return the five stage outputs ``(x1, x2, x3, x4, x5)`` for input ``x``."""
        x = self.fourier_transform(x)  # the network sees the amplitude spectrum
        x1 = self.stage1(x)
        x1p = self.pool(x1)
        x2 = self.stage2(x1p)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)
        return x1, x2, x3, x4, x5


### Search Class

In [6]:
class SearchModule(nn.Module):
    """Fuse backbone stages 2-4 into a single-channel coarse probability map.

    Each input is projected to 256 channels with a 3x3 conv. The projections
    of ``x3`` and ``x4`` are bilinearly resized to the spatial size of the
    ``x2`` projection. Each projection is converted to its FFT amplitude
    spectrum, and the three results are summed. A 1x1 conv followed by a
    sigmoid yields the coarse map at ``x2``'s resolution.

    Args:
        in_channels_list: channel counts of ``(x2, x3, x4)``.
    """

    def __init__(self, in_channels_list=(256, 512, 1024)):
        super(SearchModule, self).__init__()
        # Immutable default avoids the shared-mutable-default pitfall; any
        # iterable of three ints (including a list) is still accepted.
        self.conv_list = nn.ModuleList([
            nn.Conv2d(in_ch, 256, kernel_size=3, padding=1)
            for in_ch in in_channels_list
        ])
        self.out_conv = nn.Conv2d(256, 1, kernel_size=1)

    def fourier_transform(self, x):
        """Return the centred amplitude spectrum of ``x`` over its last two dims.

        Computes ``|fftshift(fft2(x, norm="ortho"))|`` per sample and channel.
        The result has the same shape as ``x``, dtype float32, and lives on
        ``x.device``. Unlike a NumPy round-trip, it stays in the autograd graph,
        so the projection convs above receive gradients.
        """
        # FFT on the CPU for backends without FFT kernels (e.g. older MPS);
        # .cpu()/.to() are differentiable, so gradients still flow.
        freq = fft.fft2(x.cpu(), norm="ortho")
        # Shift only the spatial dims; shifting all axes would roll the batch.
        freq = fft.fftshift(freq, dim=(-2, -1))
        return freq.abs().to(device=x.device, dtype=torch.float32)

    def forward(self, x2, x3, x4):
        """Return the sigmoid coarse map, shape ``(B, 1, H2, W2)`` with H2, W2 from ``x2``."""
        proj2 = self.conv_list[0](x2)
        size = proj2.shape[2:]  # common resolution for the fusion
        x2_ = self.fourier_transform(proj2)
        x3_ = self.fourier_transform(
            F.interpolate(self.conv_list[1](x3), size=size, mode='bilinear', align_corners=False))
        x4_ = self.fourier_transform(
            F.interpolate(self.conv_list[2](x4), size=size, mode='bilinear', align_corners=False))

        fused = x2_ + x3_ + x4_
        coarse_map = self.out_conv(fused)
        coarse_map = torch.sigmoid(coarse_map)
        return coarse_map


### Identification Class

In [7]:
class IdentificationModule(nn.Module):
    """Refine the coarse map using the deepest backbone features.

    ``x5`` is projected to 256 channels and upsampled to the coarse map's
    resolution. It is then concatenated with the coarse map (257 channels)
    and passed through a 3x3 conv, a 1x1 conv and a sigmoid.

    Args:
        in_channels: channel count of ``x5``.
    """

    def __init__(self, in_channels=2048):
        super(IdentificationModule, self).__init__()
        self.conv_deep = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
        self.refine_conv = nn.Conv2d(256 + 1, 256, kernel_size=3, padding=1)  # +1: coarse map
        self.out_conv = nn.Conv2d(256, 1, kernel_size=1)

    def forward(self, x5, coarse_map):
        """Return the refined sigmoid map with the spatial size of ``coarse_map``."""
        x5_ = self.conv_deep(x5)
        # Resize to exactly the coarse map's size rather than by a fixed factor
        # of 8. This gives the same result when the sizes differ by exactly 8x,
        # and avoids a torch.cat shape error when ResNet's rounding makes them
        # differ (input sides not divisible by 32).
        x5_up = F.interpolate(x5_, size=coarse_map.shape[-2:], mode='bilinear', align_corners=False)

        refine_input = torch.cat([x5_up, coarse_map], dim=1)
        refine_feat = self.refine_conv(refine_input)

        out_map = self.out_conv(refine_feat)
        out_map = torch.sigmoid(out_map)
        return out_map

### SINet Class

In [8]:
class SINet(nn.Module):
    """Search-Identification network for camouflaged-object segmentation.

    Pipeline: ResNet-50 backbone -> SearchModule (coarse map from stages 2-4)
    -> IdentificationModule (refinement with stage 5) -> resize to the input.

    Args:
        backbone_pretrained: load ImageNet weights for the backbone.
    """

    def __init__(self, backbone_pretrained=True):
        super(SINet, self).__init__()
        self.backbone = ResNetBackbone(pretrained=backbone_pretrained)
        self.search = SearchModule(in_channels_list=[256, 512, 1024])
        self.identify = IdentificationModule(in_channels=2048)

    def forward(self, x):
        """Return ``(out_final, coarse_map)``.

        ``out_final`` is the refined probability map at the input's spatial
        size. ``coarse_map`` is the SearchModule output at the resolution of
        backbone stage 2.
        """
        _, x2, x3, x4, x5 = self.backbone(x)  # stage-1 features are unused

        coarse_map = self.search(x2, x3, x4)
        refine_map = self.identify(x5, coarse_map)

        # Resize to the exact input size instead of a fixed x4. This matches
        # the x4 upsampling whenever the refined map is H/4 x W/4, and
        # guarantees out_final always matches the mask shape.
        out_final = F.interpolate(refine_map, size=x.shape[-2:], mode='bilinear', align_corners=False)

        return out_final, coarse_map

### Evaluation Methods

In [9]:
def compute_batch_metrics(pred, target, threshold=0.5):
    """Average pixel-level segmentation metrics over the samples of a batch.

    ``pred`` is binarised with ``pred >= threshold``; ``target`` is used as
    given (expected to contain 0/1 values). TP/FP/FN/TN are counted per
    sample, each metric is computed per sample with a small epsilon in the
    denominators, and the per-sample values are averaged.

    Args:
        pred: predicted probabilities, first dim is the batch.
        target: ground truth with the same number of elements per sample.
        threshold: binarisation threshold applied to ``pred``.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1', 'iou'
        (NaN for every key if the batch is empty).
    """
    eps = 1e-7
    batch_size = pred.shape[0]
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    if batch_size == 0:
        return {name: np.nan for name in metric_names}

    # One row per sample; reshape (unlike view) also handles non-contiguous tensors.
    p = (pred >= threshold).float().reshape(batch_size, -1)
    t = target.reshape(batch_size, -1)

    # Count for all samples at once, then do a single device->host transfer
    # instead of four .item() synchronisations per sample. The cpu() call
    # comes before double() because MPS has no float64 support.
    counts = torch.stack([
        (p * t).sum(dim=1),
        (p * (1 - t)).sum(dim=1),
        ((1 - p) * t).sum(dim=1),
        ((1 - p) * (1 - t)).sum(dim=1),
    ], dim=1)
    tp, fp, fn, tn = counts.detach().cpu().double().numpy().T

    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    iou = tp / (tp + fp + fn + eps)

    return {
        'accuracy': np.mean(acc),
        'precision': np.mean(prec),
        'recall': np.mean(rec),
        'f1': np.mean(f1),
        'iou': np.mean(iou)
    }

### Dice Loss

In [10]:
def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss computed over all elements of the batch at once.

    Both tensors are flattened into a single set of values, and the function returns
    ``1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)``.

    Args:
        pred: predicted probabilities.
        target: ground truth with the same number of elements as ``pred``.
        smooth: additive smoothing term; keeps the ratio defined when both sums are 0.

    Returns:
        0-dim tensor holding the loss.
    """
    # reshape instead of view: also works for non-contiguous inputs.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return 1 - ((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))

### Train Method

In [11]:
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader``; return the mean per-batch loss.

    Per batch, the objective is (Dice + BCE) on the full-resolution output
    plus 0.5 * (Dice + BCE) on the coarse map. The coarse map's target is the
    mask downsampled with nearest-neighbour interpolation.
    """
    model.train()
    running = 0.0
    for batch_images, batch_masks in dataloader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        final_pred, coarse_pred = model(batch_images)

        # Ground truth at the coarse map's resolution (computed once, used twice).
        coarse_target = F.interpolate(batch_masks, size=coarse_pred.shape[2:], mode='nearest')

        final_term = dice_loss(final_pred, batch_masks) + F.binary_cross_entropy(final_pred, batch_masks)
        coarse_term = dice_loss(coarse_pred, coarse_target) + F.binary_cross_entropy(coarse_pred, coarse_target)
        batch_loss = final_term + 0.5 * coarse_term

        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()
    return running / len(dataloader)

def validate_one_epoch(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(avg_loss, avg_metrics)``:
        - ``avg_loss`` is the mean per-batch value of the training objective,
          i.e. (Dice + BCE) on ``out_final`` + 0.5 * (Dice + BCE) on the coarse map.
        - ``avg_metrics`` holds accuracy/precision/recall/f1/iou at threshold
          0.5, averaged over images.
    """
    model.eval()
    val_loss = 0.0
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    per_batch = {name: [] for name in metric_names}
    batch_sizes = []

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            out_final, out_coarse = model(images)

            coarse_target = F.interpolate(masks, size=out_coarse.shape[2:], mode='nearest')
            loss_final = dice_loss(out_final, masks) + F.binary_cross_entropy(out_final, masks)
            loss_coarse = dice_loss(out_coarse, coarse_target) + F.binary_cross_entropy(out_coarse, coarse_target)
            loss = loss_final + 0.5 * loss_coarse
            val_loss += loss.item()

            batch_metrics = compute_batch_metrics(out_final, masks, threshold=0.5)
            for name in metric_names:
                per_batch[name].append(batch_metrics[name])
            batch_sizes.append(images.shape[0])

    avg_loss = val_loss / len(dataloader)
    # Each batch value is already a per-image mean. Weighting by batch size
    # gives the true per-image average even when the last batch is smaller.
    avg_metrics = {
        name: np.average(per_batch[name], weights=batch_sizes)
        for name in metric_names
    }
    return avg_loss, avg_metrics

### Test Method

In [12]:
def test_model(model, dataloader, device, threshold=0.5):
    """Compute segmentation metrics of ``model`` over ``dataloader``.

    Only ``out_final`` is evaluated; it is binarised at ``threshold``.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1', 'iou',
        each averaged over images.
    """
    model.eval()
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'iou')
    per_batch = {name: [] for name in metric_names}
    batch_sizes = []

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            out_final, _ = model(images)  # the coarse map is not evaluated

            batch_metrics = compute_batch_metrics(out_final, masks, threshold=threshold)
            for name in metric_names:
                per_batch[name].append(batch_metrics[name])
            batch_sizes.append(images.shape[0])

    # Each batch value is a per-image mean. Weighting by batch size keeps a
    # smaller final batch from being over-weighted.
    return {
        name: np.average(per_batch[name], weights=batch_sizes)
        for name in metric_names
    }

### Main

In [13]:
batch_size = 40
num_epochs = 10
lr = 1e-4
val_fraction = 0.1  # share of the Train images held out for validation
split_seed = 42     # fixed so the train/val split is reproducible

# Train and validation must not share images, otherwise the validation loss
# merely re-measures the training loss. Only the Train and Test folders are
# used here, so validation is a fixed random subset of the Train images.
# NOTE: both subsets share image_transform/mask_transform; if random
# augmentations are enabled later, validation images will be augmented too.
full_train_dataset = CODDataset(
    image_folder="COD10K-v3/Train/Image",
    mask_folder="COD10K-v3/Train/GT_Object",
    image_transform=image_transform,
    mask_transform=mask_transform
)
n_val = max(1, int(len(full_train_dataset) * val_fraction))
n_train = len(full_train_dataset) - n_val
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)
test_dataset = CODDataset("COD10K-v3/Test/Image",
                          "COD10K-v3/Test/GT_Object",
                          image_transform=image_transform,
                          mask_transform=mask_transform
                          )

# Only the training loader shuffles; num_workers=0 loads in the main process.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Build the network (ImageNet-pretrained backbone), move it to the selected
# device, and optimise all of its parameters with Adam.
model = SINet(backbone_pretrained=True)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Train for `num_epochs` epochs, reporting train and validation loss after each one.
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_loss = validate_one_epoch(model, val_loader, device)[0]  # metrics unused here
    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module.
checkpoint_path = "sinet_camouflage_fft.pth"
torch.save(model.state_dict(), checkpoint_path)
print("Training completato e modello salvato.")

In [None]:
# Evaluate on the test loader and print each metric with three decimals.
test_metrics = test_model(model, test_loader, device, threshold=0.5)
print("RISULTATI TEST FINALI:")
for label, key in (("Accuracy", "accuracy"), ("Precision", "precision"),
                   ("Recall", "recall"), ("F1-score", "f1"), ("IoU", "iou")):
    print(f"  {label} = {test_metrics[key]:.3f}")

In [116]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

def visualize_random_samples(model, dataset, device, num_images=8,
                             mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot input, ground truth and thresholded prediction for random samples.

    Draws ``num_images`` distinct random indices from ``dataset`` and runs the
    model on them as one batch. Shows one row per sample: the input image,
    the ground-truth mask, and the prediction binarised at 0.5.

    Args:
        model: network returning ``(out_final, coarse_map)``.
        dataset: dataset whose items are ``(image, mask)`` tensors.
        device: device the model lives on.
        num_images: number of samples to draw (must not exceed ``len(dataset)``).
        mean, std: per-channel normalisation applied to the images by the
            dataset transform, used to undo it for display. The defaults are
            the ImageNet statistics used in ``image_transform``.
    """
    model.eval()

    # Distinct random indices (sampling without replacement).
    random_indices = np.random.choice(len(dataset), size=num_images, replace=False)
    samples = [dataset[idx] for idx in random_indices]

    images_tensor = torch.stack([image for image, _ in samples], dim=0).to(device)
    masks_tensor = torch.stack([mask for _, mask in samples], dim=0).to(device)

    with torch.no_grad():
        out_final, _ = model(images_tensor)  # only the final map is shown
        preds = F.interpolate(out_final, size=masks_tensor.shape[2:], mode='bilinear', align_corners=False)
        preds_bin = (preds > 0.5).float()

    # Undo Normalize so imshow receives RGB values in [0, 1]. Otherwise the
    # normalised tensor is clipped and displayed with distorted colours.
    mean_t = torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1)
    images_cpu = (images_tensor.cpu() * std_t + mean_t).clamp(0.0, 1.0)
    masks_cpu = masks_tensor.cpu()
    preds_cpu = preds_bin.cpu()

    # squeeze=False always yields a 2-D array of axes, even for num_images == 1.
    _, axes = plt.subplots(nrows=num_images, ncols=3, figsize=(9, 3 * num_images), squeeze=False)

    for row, (img, msk, pred) in enumerate(zip(images_cpu, masks_cpu, preds_cpu)):
        panels = (
            (img.permute(1, 2, 0).numpy(), "Input Image", None),
            (msk.squeeze(0).numpy(), "Ground Truth", 'gray'),
            (pred.squeeze(0).numpy(), "Prediction", 'gray'),
        )
        for col, (data, title, cmap) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Show ground truth vs. prediction for 10 random test images.
visualize_random_samples(model=model, dataset=test_dataset, device=device, num_images=10)