In [10]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os


class CustomImageDataset(Dataset):
    """Image dataset built by recursively scanning a folder tree.

    Every file under ``folder_path`` whose (case-insensitive) name ends with one
    of ``extensions`` becomes a sample. The label of a sample is the name of the
    directory that directly contains the file.

    Args:
        folder_path: Root directory to scan.
        transform: Callable applied to the RGB ``PIL.Image`` of each sample.
            When ``None``, images are resized to 224x224 and converted to a tensor.
        extensions: File-name suffixes to accept (default ``('.jpg', '.png')``).

    Attributes:
        samples: List of ``(path, label)`` tuples, in deterministic order
            (directories and files are both visited in sorted order).
    """

    def __init__(self, folder_path, transform=None, extensions=('.jpg', '.png')):
        self.samples = []  # list of (path, label) tuples
        # Test against None explicitly: a valid transform may be falsy
        # (e.g. an empty nn.Sequential) and must not be silently replaced.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])
        self.transform = transform

        suffixes = tuple(ext.lower() for ext in extensions)

        # Recursive scan
        for root, dirs, files in os.walk(folder_path):
            # Sorting in place makes os.walk descend in a fixed order, so the
            # index -> sample mapping does not depend on the filesystem.
            dirs.sort()
            for file in sorted(files):
                if file.lower().endswith(suffixes):
                    path = os.path.join(root, file)
                    # The label is the name of the containing sub-folder
                    label = os.path.basename(os.path.dirname(path))
                    self.samples.append((path, label))

    def __len__(self):
        """Return the number of image files found during the scan."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label, path)``."""
        path, label = self.samples[idx]
        # Image.open is lazy; the context manager closes the file handle as soon
        # as convert() has produced an independent in-memory copy.
        with Image.open(path) as raw:
            image = raw.convert("RGB")
        image = self.transform(image)
        return image, label, path


(tensor([[[0.6431, 0.6431, 0.6431,  ..., 0.6471, 0.6471, 0.6471],
          [0.6431, 0.6431, 0.6431,  ..., 0.6471, 0.6471, 0.6471],
          [0.6431, 0.6431, 0.6431,  ..., 0.6471, 0.6471, 0.6471],
          ...,
          [0.6157, 0.5961, 0.5882,  ..., 0.6471, 0.6588, 0.6510],
          [0.6039, 0.5804, 0.5647,  ..., 0.6667, 0.6863, 0.6863],
          [0.5961, 0.5686, 0.5529,  ..., 0.6745, 0.6980, 0.7020]],
 
         [[0.7137, 0.7137, 0.7137,  ..., 0.7137, 0.7137, 0.7137],
          [0.7137, 0.7137, 0.7137,  ..., 0.7137, 0.7137, 0.7137],
          [0.7137, 0.7137, 0.7137,  ..., 0.7137, 0.7137, 0.7137],
          ...,
          [0.4745, 0.4510, 0.4392,  ..., 0.4235, 0.4157, 0.3922],
          [0.4627, 0.4353, 0.4157,  ..., 0.4431, 0.4392, 0.4235],
          [0.4549, 0.4235, 0.4000,  ..., 0.4510, 0.4549, 0.4392]],
 
         [[0.8000, 0.8000, 0.8000,  ..., 0.7922, 0.7922, 0.7922],
          [0.8000, 0.8000, 0.8000,  ..., 0.7922, 0.7922, 0.7922],
          [0.8000, 0.8000, 0.8000,  ...,

In [11]:
from torch.utils.data import DataLoader

# Dataset over the example training folder; no transform is passed, so the
# dataset's built-in default transform is used.
training_dataset = CustomImageDataset("./data_example/training")
# Load the first sample once. The value is discarded here because it is not
# the last expression of the cell.
training_dataset[0]

# Shuffled mini-batches of 32 samples, loaded by 4 worker processes.
dataloader = DataLoader(
    dataset=training_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [17]:
import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean margin-based triplet loss.

    For every triplet the loss is ``max(d(a, p) - d(a, n) + margin, 0)``, where
    ``d`` is ``F.pairwise_distance`` with ``p=2``; the per-triplet values are
    averaged into a single scalar.

    Args:
        anchor, positive, negative: Embedding tensors accepted by
            ``F.pairwise_distance`` (presumably ``[batch, dim]`` - confirm with caller).
        margin: Minimum gap required between the negative and positive distances.

    Returns:
        Scalar tensor holding the mean loss.
    """
    d_anchor_pos = F.pairwise_distance(anchor, positive, p=2)
    d_anchor_neg = F.pairwise_distance(anchor, negative, p=2)
    # Hinge: triplets that already satisfy the margin contribute zero.
    hinge = (d_anchor_pos - d_anchor_neg + margin).clamp(min=0.0)
    return hinge.mean()

import torch.nn as nn
from torchvision.models import resnet18

class EmbeddingNet(nn.Module):
    """ResNet-18 feature extractor followed by a linear projection to an
    L2-normalised embedding.

    Args:
        embedding_dim: Size of the output embedding (default 128).
        pretrained: When True (default), the backbone is initialised with the
            ImageNet weights shipped by torchvision; when False it is randomly
            initialised.
    """

    def __init__(self, embedding_dim=128, pretrained=True):
        super().__init__()
        base = self._load_resnet18(pretrained)
        # Drop the classifier: keep every child module except the last one.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        # Final embedding layer; its input width is read from the classifier
        # that was removed instead of being hard-coded.
        self.fc = nn.Linear(base.fc.in_features, embedding_dim)

    @staticmethod
    def _load_resnet18(pretrained):
        """Build a ResNet-18, preferring the non-deprecated ``weights`` argument."""
        try:
            # torchvision >= 0.13: `pretrained=True` is deprecated in favour of
            # `weights`; "IMAGENET1K_V1" is what `pretrained=True` maps to.
            return resnet18(weights="IMAGENET1K_V1" if pretrained else None)
        except TypeError:
            # Older torchvision does not know the `weights` keyword.
            return resnet18(pretrained=pretrained)

    def forward(self, x):
        """Return one unit-length embedding per input sample."""
        x = self.backbone(x)     # [B, 512, 1, 1] with the ResNet-18 backbone
        x = torch.flatten(x, 1)  # -> [B, features]; also safe for non-contiguous input
        x = self.fc(x)
        return F.normalize(x, p=2, dim=1)  # L2-normalise the embedding
import random

class TripletDataset(torch.utils.data.Dataset):
    """Serve ``(anchor, positive, negative)`` image triplets from a labelled dataset.

    ``base_dataset[i]`` must return a 3-tuple ``(image, label, path)``. For item
    ``i`` the anchor is sample ``i``, the positive is a random other sample with
    the same label, and the negative is a random sample of a randomly chosen
    different label. Sampling uses the global ``random`` module.

    Args:
        base_dataset: Indexable dataset with ``__len__``. If it exposes a
            ``samples`` sequence of the same length whose entries hold the label
            at position 1 (as ``(path, label)``), labels are read from there
            without loading any image.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.label_to_indices = self._build_index()

    def _build_index(self):
        """Return a dict mapping each label to the list of indices that carry it."""
        dataset = self.base_dataset
        samples = getattr(dataset, "samples", None)
        if isinstance(samples, (list, tuple)) and len(samples) == len(dataset):
            # Fast path: read labels from the sample list so that building the
            # index does not decode and transform every image.
            labels = (sample[1] for sample in samples)
        else:
            labels = (dataset[idx][1] for idx in range(len(dataset)))

        label_to_indices = {}
        for idx, label in enumerate(labels):
            label_to_indices.setdefault(label, []).append(idx)
        return label_to_indices

    def __getitem__(self, index):
        """Return ``(anchor_img, positive_img, negative_img)`` for ``index``.

        Raises:
            ValueError: If the dataset has no label other than the anchor's, so
                no negative can be drawn.
        """
        anchor_img, anchor_label, _ = self.base_dataset[index]

        # Positive sample (same class, different index when one exists)
        same_class = self.label_to_indices[anchor_label]
        positive_idx = index
        if len(same_class) > 1:
            while positive_idx == index:
                positive_idx = random.choice(same_class)
        # else: the class has a single sample; reuse the anchor as its own
        # positive instead of looping forever looking for another index.
        positive_img, _, _ = self.base_dataset[positive_idx]

        # Negative sample (different class)
        other_labels = [l for l in self.label_to_indices.keys() if l != anchor_label]
        if not other_labels:
            # Raise ValueError rather than letting random.choice raise IndexError:
            # an IndexError from __getitem__ is read as "end of data" by plain
            # iteration and would silently truncate the dataset.
            raise ValueError(
                "TripletDataset needs at least two distinct labels to draw a negative sample"
            )
        negative_label = random.choice(other_labels)
        negative_idx = random.choice(self.label_to_indices[negative_label])
        negative_img, _, _ = self.base_dataset[negative_idx]

        return anchor_img, positive_img, negative_img

    def __len__(self):
        """Return the number of anchors, i.e. the length of the base dataset."""
        return len(self.base_dataset)
