In [1]:
print('Hello')

Hello


In [2]:
import torch
import os
import random
from PIL import Image
from collections import defaultdict
import re
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

def is_image_file(file_path):
    """Tell whether ``file_path`` ends in a known image extension.

    The check is purely name-based (case-insensitive); the file is never
    opened, so it does not have to exist.

    Args:
        file_path: File name or path to inspect.

    Returns:
        bool: True when the extension is one of .jpg, .jpeg, .png, .gif,
        .bmp, .tiff or .webp.
    """
    _, suffix = os.path.splitext(file_path)
    known_suffixes = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp")
    return suffix.lower() in known_suffixes


def is_valid_part_format(s):
    """Return True when ``s`` is exactly ``part1`` through ``part14``.

    Args:
        s: Candidate folder name.

    Returns:
        bool: True only for the literal word ``part`` followed by an integer
        in the range 1-14 without leading zeros; False otherwise.
    """
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, so "part1\n" used to be accepted.
    return re.fullmatch(r"part([1-9]|1[0-4])", s) is not None


class TripletDataset(torch.utils.data.Dataset):
    """Anchor/positive image pairs discovered from a product folder tree.

    Each product folder contributes at most one sample: the first image file
    found directly inside it (anchor) and the first image found inside one of
    its sub-directories (positive, i.e. a review image). ``__getitem__``
    returns only ``(anchor, positive)``; negatives are tracked for subclasses
    that need them.
    """

    def __init__(
        self,
        root_dir,
        transform=None,
        sample_negatives="epoch",
        limit=-1,
        neg_only_reviews=True,
        type='train'
    ):
        """
        Args:
            root_dir: Directory containing the ``part1`` ... ``part14`` folders.
            transform: Optional callable applied to every loaded PIL image.
            sample_negatives: Negative sampling strategy.
                - "batch": a negative is drawn when an item is requested.
                - "epoch": negatives are precomputed here and refreshed via
                  ``update_negatives()``.
                Any other value makes ``__getitem__`` raise ``ValueError``.
            limit: Stop scanning once this many samples were collected;
                values <= 0 disable the limit.
            neg_only_reviews: With "epoch" sampling, always use the review
                image of the negative product instead of a random one of its
                two images.
            type: 'train' expects the nested layout
                ``root_dir/partN/partN/<product>``; any other value expects
                ``root_dir/partN/<product>``.

        Raises:
            IndexError: With "epoch" sampling when exactly one product was
                found (there is nothing to draw a negative from).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.sample_negatives = sample_negatives
        self.neg_only_reviews = neg_only_reviews

        self.class_to_images = defaultdict(list)  # { product: [anchor, positive] }
        self.samples = []  # [(anchor_path, positive_path, product)]

        # Read dataset structure
        for part_folder in os.listdir(root_dir):
            if self._limit_reached(limit):
                break
            if type == 'train':
                part_path = os.path.join(root_dir, part_folder, part_folder)
            else:
                part_path = os.path.join(root_dir, part_folder)
            print(part_path)
            if not os.path.isdir(part_path) or not is_valid_part_format(part_folder):
                continue
            for product_folder in os.listdir(part_path):
                if self._limit_reached(limit):
                    break

                product_path = os.path.join(part_path, product_folder)
                if not os.path.isdir(product_path):
                    continue

                pair = self._find_pair(product_path)
                if pair is None:
                    continue

                anchor, positive = pair
                self.class_to_images[product_folder] = [anchor, positive]
                self.samples.append((anchor, positive, product_folder))

        # Precompute negatives if needed
        if self.sample_negatives == "epoch":
            self.negative_map = self.assign_negatives()

    def _limit_reached(self, limit):
        """Return True once ``limit`` (> 0) samples have been collected."""
        return limit > 0 and len(self.samples) >= limit

    @staticmethod
    def _find_pair(product_path):
        """Return ``(anchor, positive)`` for one product folder, or None.

        The scan stops at the first complete pair so that a product yields
        exactly one sample; continuing the scan would register the same pair
        again for every remaining directory entry.
        """
        entries = [
            os.path.join(product_path, name) for name in os.listdir(product_path)
        ]
        # Need at least a product image and a review folder.
        if len(entries) < 2:
            return None

        anchor = None
        positive = None
        for entry in entries:
            if os.path.isdir(entry):  # review folder
                review_images = [
                    path
                    for path in (
                        os.path.join(entry, name) for name in os.listdir(entry)
                    )
                    if is_image_file(path)
                ]
                # Only the first image of a review folder is used.
                if review_images:
                    positive = review_images[0]
            elif is_image_file(entry) and anchor is None:
                anchor = entry

            if anchor is not None and positive is not None:
                return anchor, positive
        return None

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB image and release the file handle."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def assign_negatives(self):
        """Assign every product a random negative image from another product.

        Returns:
            dict: product label -> path of the negative image.

        Raises:
            IndexError: If exactly one product is known.
        """
        labels = list(self.class_to_images)
        total = len(labels)
        if total == 1:
            raise IndexError(
                "Cannot sample a negative: only one product class is available"
            )

        negative_map = {}
        for idx, product_label in enumerate(labels):
            # Uniform draw over all indices except idx without materialising
            # the list of "other" classes for every product.
            other = random.randrange(total - 1)
            if other >= idx:
                other += 1
            candidate_images = self.class_to_images[labels[other]]
            if self.neg_only_reviews:
                negative_map[product_label] = candidate_images[1]
            else:
                negative_map[product_label] = random.choice(candidate_images)

        return negative_map

    def __getitem__(self, index):
        """Return the transformed ``(anchor, positive)`` pair at ``index``.

        Raises:
            ValueError: If ``sample_negatives`` is neither 'batch' nor 'epoch'.
        """
        anchor_path, positive_path, _ = self.samples[index]

        # This class never returns the negative, so only the strategy is
        # validated here; no negative is sampled.
        if self.sample_negatives not in ("batch", "epoch"):
            raise ValueError("Unsupported sampling strategy. Use 'batch' or 'epoch'.")

        anchor = self._load_rgb(anchor_path)
        positive = self._load_rgb(positive_path)

        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)

        return anchor, positive

    def __len__(self):
        """Number of (anchor, positive) samples."""
        return len(self.samples)

    def update_negatives(self):
        """Call this at the start of each epoch if using 'epoch' sampling."""
        if self.sample_negatives == "epoch":
            self.negative_map = self.assign_negatives()



class EvalTripletDataset(TripletDataset):
    """Evaluation variant that yields full (anchor, positive, negative) triplets."""

    def assign_negatives(self):
        """
        Assign a negative sample to every product.

        Products are walked in the order they appear in ``class_to_images``
        and each one receives the second stored image of the *next* product;
        the last product wraps around to the first (circular indexing), so
        the assignment involves no randomness.

        Returns:
            dict: Mapping from product label to the path of its negative image.
        """
        labels = list(self.class_to_images.keys())
        # Rotating the list by one position pairs every label with its successor.
        successors = labels[1:] + labels[:1]
        return {
            label: self.class_to_images[successor][1]
            for label, successor in zip(labels, successors)
        }

    def __getitem__(self, index):
        """Load, transform and return the triplet stored at ``index``.

        Raises:
            ValueError: If ``sample_negatives`` is neither 'batch' nor 'epoch'.
        """
        anchor_path, positive_path, product_label = self.samples[index]
        strategy = self.sample_negatives

        # Resolve the negative path according to the sampling strategy
        if strategy == "epoch":
            negative_path = self.negative_map[product_label]
        elif strategy == "batch":
            other_products = [
                label for label in self.class_to_images.keys() if label != product_label
            ]
            chosen_product = random.choice(other_products)
            negative_path = random.choice(self.class_to_images[chosen_product])
        else:
            raise ValueError("Unsupported sampling strategy. Use 'batch' or 'epoch'.")

        # Open all three images first, then transform them in the same order
        triplet = [
            Image.open(path).convert("RGB")
            for path in (anchor_path, positive_path, negative_path)
        ]
        if self.transform:
            triplet = [self.transform(image) for image in triplet]

        anchor, positive, negative = triplet
        return anchor, positive, negative

In [3]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt


def test_image_embedding_model(
    image1_path, image2_path, model, threshold, metric="l2", show_images=True
):
    """
    Check how similar two images are according to an embedding model.

    Args:
        image1_path (str): Path of the first image.
        image2_path (str): Path of the second image.
        model (torch.nn.Module): Model that maps an image batch to embeddings.
        threshold (float): Decision boundary between "similar" and "different".
        metric (str): Either 'cosine' (similar when similarity >= threshold)
            or 'l2' (similar when distance <= threshold).
        show_images (bool): If True, display both images side by side with the
            score as the figure title.

    Returns:
        int: 1 if the images are judged similar, 0 otherwise.

    Raises:
        ValueError: If ``metric`` is neither 'cosine' nor 'l2'.
    """
    # Fail before any image loading / inference when the metric is unknown.
    if metric not in ("cosine", "l2"):
        raise ValueError("metric phải là 'cosine' hoặc 'l2'")

    # Run on the device the model already lives on; moving the inputs to CUDA
    # unconditionally fails for a model that was left on the CPU.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Image preprocessing
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def preprocess_image(image_path):
        # The context manager releases the file handle once the pixels are read.
        with Image.open(image_path) as image:
            return transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    img1 = preprocess_image(image1_path).to(device)
    img2 = preprocess_image(image2_path).to(device)

    # Feed the tensors through the model to obtain the embeddings
    model.eval()
    with torch.no_grad():
        emb1 = model(img1).squeeze(0)  # drop the batch dimension added above
        emb2 = model(img2).squeeze(0)

    # Compute the similarity or the distance
    if metric == "cosine":
        similarity = F.cosine_similarity(emb1, emb2, dim=0).item()
        score_text = f"Similarity: {similarity:.4f}"
        result = 1 if similarity >= threshold else 0
    else:  # metric == "l2"
        distance = torch.norm(emb1 - emb2, p=2).item()
        score_text = f"Distance: {distance:.4f}"
        result = 1 if distance <= threshold else 0

    # Display the images if requested
    if show_images:
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        for axis, path, title in (
            (axes[0], image1_path, "Image 1"),
            (axes[1], image2_path, "Image 2"),
        ):
            with Image.open(path) as image:
                axis.imshow(image)
            axis.set_title(title)
            axis.axis("off")

        plt.suptitle(score_text)
        plt.show()

    print(score_text)

    return result


def evaluate_batch(anchor, pos, neg, threshold=1.25, metric="l2"):
    """
    Count confusion-matrix entries for one batch of embedding triplets.

    Every anchor/positive pair is a ground-truth "match" and every
    anchor/negative pair a ground-truth "non-match". A pair is predicted as a
    match when its cosine similarity is >= ``threshold`` (metric 'cosine') or
    its L2 distance is <= ``threshold`` (metric 'l2').

    Args:
        anchor (torch.Tensor): Anchor embeddings; the metric is reduced over
            the last dimension.
        pos (torch.Tensor): Positive embeddings, same layout as ``anchor``.
        neg (torch.Tensor): Negative embeddings, same layout as ``anchor``.
        threshold (float): Decision threshold for "same item".
        metric (str): 'cosine' or 'l2'.

    Returns:
        tuple: (TP, TN, FP, FN)

    Raises:
        ValueError: If ``metric`` is neither 'cosine' nor 'l2'.
    """
    if metric not in ("cosine", "l2"):
        raise ValueError("Metric must be 'cosine' or 'l2'")

    if metric == "cosine":
        # Higher score means more similar
        pos_score = F.cosine_similarity(anchor, pos, dim=-1)
        neg_score = F.cosine_similarity(anchor, neg, dim=-1)

        tp = (pos_score >= threshold).sum().item()  # positives accepted
        fn = (pos_score < threshold).sum().item()  # positives rejected
        tn = (neg_score < threshold).sum().item()  # negatives rejected
        fp = (neg_score >= threshold).sum().item()  # negatives accepted
    else:
        # Smaller distance means more similar
        pos_gap = (anchor - pos).norm(p=2, dim=-1)
        neg_gap = (anchor - neg).norm(p=2, dim=-1)

        tp = (pos_gap <= threshold).sum().item()  # positives accepted
        fn = (pos_gap > threshold).sum().item()  # positives rejected
        tn = (neg_gap > threshold).sum().item()  # negatives rejected
        fp = (neg_gap <= threshold).sum().item()  # negatives accepted

    return tp, tn, fp, fn


def evaluate_metrics(tp, tn, fp, fn):
    """Derive classification metrics from confusion-matrix counts.

    Args:
        tp: Number of true positives.
        tn: Number of true negatives.
        fp: Number of false positives.
        fn: Number of false negatives.

    Returns:
        dict: Keys "Accuracy", "Precision", "Recall", "F1-Score",
        "False Positive Rate (FPR)" and "False Negative Rate (FNR)".
        Any metric whose denominator is zero is reported as 0.
    """

    def _ratio(numerator, denominator):
        # Shared zero-denominator guard; accuracy is guarded as well so that
        # an all-zero confusion matrix (e.g. an empty evaluation set) does
        # not raise ZeroDivisionError.
        return numerator / denominator if denominator != 0 else 0

    accuracy = _ratio(tp + tn, tp + tn + fp + fn)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1_score = _ratio(2 * (precision * recall), precision + recall)
    fpr = _ratio(fp, fp + tn)
    fnr = _ratio(fn, fn + tp)

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1_score,
        "False Positive Rate (FPR)": fpr,
        "False Negative Rate (FNR)": fnr,
    }

In [4]:
import torch
import wandb
from typing import Literal


class SpamDetectorTrainer:
    """Training loop for an embedding model trained on (anchor, positive) pairs.

    Runs training epochs, validates on (anchor, positive, negative) triplets,
    optionally logs to a W&B-style writer, saves a checkpoint whenever the
    validation loss improves and stops early after ``patience`` epochs
    without improvement.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        valid_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler=None,
        lr_types: Literal["step", "epoch"] = "step",
        device: torch.device = "cuda",
        epochs: int = 10,
        max_norm: float = 0,
        log_writer=None,
        patience=3,
    ):
        """
        Args:
            model: Network producing embeddings; moved to ``device``.
            criterion: Loss called as ``criterion(anchor_emb, positive_emb)``.
            train_loader: Iterable of ``(anchor, positive)`` batches; must
                support ``len()``.
            valid_loader: Iterable of ``(anchor, positive, negative)``
                batches; must support ``len()``.
            optimizer: Optimizer updating ``model``.
            scheduler: Optional learning-rate scheduler.
            lr_types: "step" to step the scheduler after every batch,
                "epoch" to step it once per epoch.
            device: Device on which the model and batches are placed.
            epochs: Number of epochs ``train()`` runs (in addition to the
                epochs already stored in a resumed checkpoint).
            max_norm: Gradient-clipping norm; values <= 0 disable clipping.
            log_writer: Optional object exposing ``log(dict)`` (for example
                the ``wandb`` module or a run). Left unannotated because a
                module is not a valid type annotation.
            patience: Epochs without validation improvement before stopping.
        """
        self.model = model.to(device)
        self.criterion = criterion
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.lr_types = lr_types
        self.device = device
        self.epochs = epochs
        self.max_norm = max_norm
        self.log_writer = log_writer
        self.patience = patience

    def _mean_grad_norm(self):
        """Return the mean of the per-parameter gradient L2 norms (0 if none).

        The norms are stacked on-device and transferred once, instead of one
        host synchronisation per parameter tensor.
        """
        grads = [
            param.grad for param in self.model.parameters() if param.grad is not None
        ]
        if not grads:
            return 0
        target_device = grads[0].device
        norms = torch.stack(
            [grad.detach().norm().float().to(target_device) for grad in grads]
        )
        return norms.mean().item()

    def train_one_epoch(self, epoch):
        """Train for one epoch and return the mean batch loss.

        Args:
            epoch: Zero-based epoch index, used for the logged global step.
        """
        self.model.train()
        total_loss = 0
        num_batches = len(self.train_loader)
        global_step = epoch * num_batches  # step offset of this epoch

        for step, (anchor, positive) in enumerate(self.train_loader):
            anchor, positive = (
                anchor.to(self.device),
                positive.to(self.device),
            )

            self.optimizer.zero_grad()
            anchor_embeded = self.model(anchor)
            positive_embeded = self.model(positive)
            loss = self.criterion(anchor_embeded, positive_embeded)

            loss.backward()

            # Measured before clipping so the logged value shows raw gradients
            mean_grad_norm = self._mean_grad_norm()

            # Apply gradient clipping if enabled
            if self.max_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)

            self.optimizer.step()

            # Step scheduler per step
            if self.scheduler and self.lr_types == "step":
                self.scheduler.step()

            batch_loss = loss.item()

            # Log to W&B
            if self.log_writer:
                self.log_writer.log(
                    {
                        "train/loss": batch_loss,
                        "train/learning_rate": self.optimizer.param_groups[0]["lr"],
                        "train/grad_norm": mean_grad_norm,
                        "train/global_step": global_step + step,
                        "train/epoch": epoch + step / num_batches,
                    }
                )

            total_loss += batch_loss

        epoch_loss = total_loss / num_batches
        if self.log_writer:
            self.log_writer.log({"train/mean_loss": epoch_loss})

        # Step scheduler per epoch
        if self.scheduler and self.lr_types == "epoch":
            self.scheduler.step()

        return epoch_loss

    def validate(self):
        """Evaluate on the validation loader and return the mean loss.

        The loss is computed on (anchor, positive) only; the negative
        embedding is used for the confusion-matrix metrics, which are logged
        when a ``log_writer`` is set.
        """
        self.model.eval()
        running_loss = 0.0
        tp = 0
        tn = 0
        fp = 0
        fn = 0

        with torch.no_grad():
            for anchor, positive, negative in self.valid_loader:
                anchor, positive, negative = (
                    anchor.to(self.device),
                    positive.to(self.device),
                    negative.to(self.device),
                )
                anchor_embeded = self.model(anchor)
                positive_embeded = self.model(positive)
                negative_embeded = self.model(negative)
                loss = self.criterion(anchor_embeded, positive_embeded)
                running_loss += loss.item()

                # evaluate_batch returns (TP, TN, FP, FN)
                batch_tp, batch_tn, batch_fp, batch_fn = evaluate_batch(
                    anchor_embeded, positive_embeded, negative_embeded
                )
                tp += batch_tp
                tn += batch_tn
                fp += batch_fp
                fn += batch_fn

        metrics = evaluate_metrics(tp, tn, fp, fn)

        # Calculate average loss
        avg_loss = running_loss / len(self.valid_loader)

        # Log metrics and loss to W&B
        if self.log_writer:
            self.log_writer.log(
                {
                    "val/accuracy": metrics["Accuracy"],
                    "val/precision": metrics["Precision"],
                    "val/recall": metrics["Recall"],
                    "val/f1_score": metrics["F1-Score"],
                    "val/False Positive Rate": metrics["False Positive Rate (FPR)"],
                    "val/False Negative Rate": metrics["False Negative Rate (FNR)"],
                    "val/val_loss": avg_loss,
                }
            )
        return avg_loss

    def train(self, resume_from_checkpoint=None):
        """Run the training loop with checkpointing and early stopping.

        Args:
            resume_from_checkpoint: Optional path of a checkpoint written by
                this method; training continues after the stored epoch.

        A file ``checkpoint_{epoch}.pth`` is written to the current working
        directory each time the validation loss improves.
        """
        start_epoch = 0
        best_val_loss = float("inf")
        epochs_without_improvement = 0  # Track epochs without improvement

        # Load checkpoint if provided
        if resume_from_checkpoint:
            checkpoint = torch.load(resume_from_checkpoint, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

            # The key is always written on save but holds None when training
            # ran without a scheduler, so test the value and not just the key.
            scheduler_state = checkpoint.get("scheduler_state_dict")
            if self.scheduler and scheduler_state is not None:
                self.scheduler.load_state_dict(scheduler_state)

            start_epoch = checkpoint["epoch"] + 1
            best_val_loss = checkpoint.get("best_val_loss", float("inf"))
            epochs_without_improvement = checkpoint.get("epochs_without_improvement", 0)

            print(f"Resuming training from epoch {start_epoch}")
        else:
            print("Start training!")

        # self.epochs counts the epochs of THIS run, so the last epoch index
        # is offset by the resumed epoch; report progress against that total.
        final_epoch = start_epoch + self.epochs
        for epoch in range(start_epoch, final_epoch):
            epoch_loss = self.train_one_epoch(epoch)
            val_loss = self.validate()

            print(
                f"Epoch [{epoch+1}/{final_epoch}], Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}"
            )

            # Check if validation loss improved
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0  # Reset counter
                print(
                    f"New best validation loss: {best_val_loss:.4f}. Saving checkpoint."
                )

                checkpoint = {
                    "epoch": epoch,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "scheduler_state_dict": (
                        self.scheduler.state_dict() if self.scheduler else None
                    ),
                    "best_val_loss": best_val_loss,
                    "epochs_without_improvement": epochs_without_improvement,
                }
                torch.save(checkpoint, f"checkpoint_{epoch}.pth")
            else:
                epochs_without_improvement += 1
                print(
                    f"No improvement for {epochs_without_improvement}/{self.patience} epochs."
                )

            # Stop training if no improvement for `self.patience` epochs
            if epochs_without_improvement >= self.patience:
                print(
                    f"Validation loss hasn't improved for {self.patience} epochs. Stopping training early."
                )
                break

        print("Training complete.")


In [5]:
from torch import nn


class EmbeddingHead(nn.Module):
    """Projection head: Linear -> BatchNorm1d -> Dropout -> optional L2 norm.

    Args:
        in_features: Size of the incoming feature vector.
        embedding_dim: Size of the produced embedding.
        normalize: When True the output rows are scaled to unit L2 norm.
        dropout_rate: Dropout probability applied after batch normalisation.
    """

    def __init__(self, in_features=768, embedding_dim=512, normalize=True, dropout_rate=0.1):
        super().__init__()
        self.normalize_output = normalize
        self.embedding = nn.Linear(in_features, embedding_dim)
        self.bn = nn.BatchNorm1d(embedding_dim)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Project ``x`` to the embedding space; the norm is taken over dim 1."""
        projected = self.dropout(self.bn(self.embedding(x)))
        if not self.normalize_output:
            return projected
        return nn.functional.normalize(projected, p=2, dim=1)

In [6]:
class NTXentLoss(torch.nn.Module):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Each of the 2N embeddings is classified against every other embedding of
    the batch; its paired view is the target, all remaining ones act as
    negatives.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """
        z_i: (batch_size, embedding_dim) - Embedding of original images
        z_j: (batch_size, embedding_dim) - Embedding of positive images

        Returns the scalar loss averaged over all 2 * batch_size embeddings.
        """
        batch_size = z_i.shape[0]
        n_samples = 2 * batch_size

        # Unit-normalise so the dot products below are cosine similarities,
        # then stack both views into one (2N, embedding_dim) tensor.
        z = torch.cat([F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0)

        # Temperature-scaled similarity matrix, shape (2N, 2N)
        sim = torch.matmul(z, z.T) / self.temperature

        # An embedding must not be its own candidate. -inf is representable in
        # every floating dtype (a large finite constant such as -1e9 overflows
        # in float16) and contributes exactly zero probability mass.
        self_mask = torch.eye(n_samples, dtype=torch.bool, device=z.device)
        logits = sim.masked_fill(self_mask, float("-inf"))

        # The positive of sample i sits at index (i + N) mod 2N. Using it as
        # the class target is equivalent to moving the positive to column 0
        # and concatenating the remaining columns as negatives.
        targets = (torch.arange(n_samples, device=z.device) + batch_size) % n_samples

        return F.cross_entropy(logits, targets)


class NTXentLossHardNegatives(torch.nn.Module):
    """NT-Xent loss that can restrict each row to its hardest negatives."""

    def __init__(self, temperature=0.5, top_k_negatives=None):
        """
        Args:
            temperature: Scaling factor for similarities.
            top_k_negatives (int, optional): Number of highest-similarity
                negatives kept per row. ``None`` (default) keeps all of them.
        """
        super().__init__()
        self.temperature = temperature
        self.top_k_negatives = top_k_negatives
        # -inf guarantees excluded entries are never picked by topk
        self.mask_value = -float('inf')

    def forward(self, z_i, z_j):
        """
        z_i: (batch_size, embedding_dim) - Embedding of product images (or reviews)
        z_j: (batch_size, embedding_dim) - Embedding of corresponding review images (or products)

        Returns the scalar cross-entropy over [positive, selected negatives].
        """
        device = z_i.device
        batch_size = z_i.shape[0]

        # Stack the unit-normalised views: [p1..pn, r1..rn] -> (2N, dim)
        z = torch.cat([F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0)
        n_samples = z.shape[0]

        # Temperature-scaled cosine similarities, shape (2N, 2N)
        sim = torch.matmul(z, z.T) / self.temperature

        # Row r is paired with column (r + N) mod 2N
        rows = torch.arange(n_samples, device=device)
        partner = (rows + batch_size) % n_samples
        positives = sim[rows, partner].unsqueeze(1)  # shape: (2N, 1)

        # Neither the sample itself nor its positive may serve as a negative
        excluded = torch.eye(n_samples, dtype=torch.bool, device=device)
        excluded[rows, partner] = True
        candidates = sim.masked_fill(excluded, self.mask_value)

        # Number of negatives to keep, capped by what a row actually offers
        available = n_samples - 2
        if self.top_k_negatives is None:
            keep = available
        else:
            keep = min(self.top_k_negatives, available)

        if keep > 0:
            # Highest similarities first, i.e. the hardest negatives
            negatives, _ = torch.topk(candidates, keep, dim=1, largest=True)
        else:
            # No negatives requested or available (e.g. batch_size == 1)
            negatives = torch.empty((n_samples, 0), device=device)

        # Column 0 holds the positive, so every row's target class is 0
        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(n_samples, dtype=torch.long, device=device)

        return F.cross_entropy(logits, labels)


In [None]:

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
import torch
from torchvision.models import swin_v2_t, Swin_V2_T_Weights

import wandb

# Training setup cell: defines hyperparameters, builds the Swin-V2-T embedding
# model, the image transform pipelines, the train/val datasets and loaders,
# the optimizer, loss and LR schedule, then pulls one batch from each loader
# as a shape sanity check. Every name bound here is a notebook global.
data_path = "/kaggle/input/review-thesis-datasets"
val_data_path = "/kaggle/input/val-review-thesis-datasets"
# Hyperparameters
base_lr = 2e-5  # Base learning rate; given to Adam and used as OneCycleLR's max_lr below
num_epochs = 25
# Passed as `limit` to the datasets; TripletDataset only enforces the limit when
# it is > 0, so -1 means "no limit" there.
# NOTE(review): presumably EvalTripletDataset treats -1 the same way — confirm.
dataset_size = -1
val_dataset_size = -1
batch_size = 64
warmup_ratio = 0.1  # Fraction of the schedule used for warmup (OneCycleLR pct_start)
device = "cuda" if torch.cuda.is_available() else "cpu"

# ImageNet-pretrained Swin-V2-T with its head swapped for the project's
# EmbeddingHead (configured 768 -> 256), wrapped in DataParallel.
model = swin_v2_t(weights=Swin_V2_T_Weights.IMAGENET1K_V1)
model.head = EmbeddingHead(in_features=768, embedding_dim=256)
model = torch.nn.DataParallel(model)
model.to(device)

# Earlier augmentation pipeline. It is not referenced again in this cell; the
# training dataset below is built with `transform`.
old_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Resize every image to 224x224
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.RandomHorizontalFlip(p=0.5),  # Horizontal flip with 50% probability
        transforms.ColorJitter(brightness=0.3),  # Adjust brightness (±30%)
        transforms.RandomPerspective(
            distortion_scale=0.5, p=0.5
        ),  # Perspective distortion
        transforms.RandomRotation(degrees=30),  # Rotate within ±30 degrees
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Training augmentation: photometric changes only (colour jitter, grayscale,
# blur), applied before tensor conversion and normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Resize every image to 224x224
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        ], p=0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)) # Kernel size needs to be odd
        ], p=0.4),
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Validation pipeline: deterministic resize + normalization, no augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Resize every image to 224x224
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

print('Load train datasets!')
dataset = TripletDataset(data_path, transform=transform, limit=dataset_size)
print('Load val datasets!')
val_dataset = EvalTripletDataset(val_data_path, transform=val_transform, limit=val_dataset_size, type='val')


# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Initialize the Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=base_lr)
criterion = NTXentLossHardNegatives(temperature=0.25, top_k_negatives=64)
# The schedule is sized as num_epochs * len(train_loader) steps.
# NOTE(review): this assumes the trainer steps the scheduler once per batch — confirm.
lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=base_lr,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=warmup_ratio
)

# Iterate over the dataloader to fetch a batch
# NOTE(review): the second tensor is printed as "Labels" but has image shape in
# the recorded output — presumably the positive view; confirm against
# TripletDataset.__getitem__.
for images, labels in train_loader:
    print(
        f"Batch size: {images.shape}, Labels: {labels.shape}"
    )
    break  # Stop after the first batch

# Same sanity check for validation, whose batches carry a third tensor.
for images, labels, negative in val_loader:
    print(
        f"Batch size: {images.shape}, Labels: {labels.shape}, Negative: {negative.shape}"
    )
    break  # Stop after the first batch

Downloading: "https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth" to /root/.cache/torch/hub/checkpoints/swin_v2_t-b137f0e2.pth



  0%|          | 0.00/109M [00:00<?, ?B/s]


  3%|▎         | 3.12M/109M [00:00<00:03, 32.7MB/s]


 15%|█▌        | 16.5M/109M [00:00<00:01, 83.4MB/s]


 30%|███       | 32.9M/109M [00:00<00:00, 101MB/s] 


 45%|████▌     | 49.1M/109M [00:00<00:00, 124MB/s]


 56%|█████▋    | 61.1M/109M [00:00<00:00, 76.4MB/s]


 65%|██████▍   | 70.4M/109M [00:00<00:00, 68.9MB/s]


 75%|███████▌  | 82.0M/109M [00:01<00:00, 69.7MB/s]


 91%|█████████ | 98.4M/109M [00:01<00:00, 84.4MB/s]


100%|██████████| 109M/109M [00:01<00:00, 87.6MB/s] 




Load train datasets!
/kaggle/input/review-thesis-datasets/part7/part7


/kaggle/input/review-thesis-datasets/part12/part12


/kaggle/input/review-thesis-datasets/part1/part1


/kaggle/input/review-thesis-datasets/part10/part10


/kaggle/input/review-thesis-datasets/part6/part6


/kaggle/input/review-thesis-datasets/part13/part13


/kaggle/input/review-thesis-datasets/part2/part2


/kaggle/input/review-thesis-datasets/part11/part11


/kaggle/input/review-thesis-datasets/part8/part8


/kaggle/input/review-thesis-datasets/part3/part3


/kaggle/input/review-thesis-datasets/part4/part4


/kaggle/input/review-thesis-datasets/part5/part5


Load val datasets!
/kaggle/input/val-review-thesis-datasets/part9


Batch size: torch.Size([64, 3, 224, 224]), Labels: torch.Size([64, 3, 224, 224])


Batch size: torch.Size([64, 3, 224, 224]), Labels: torch.Size([64, 3, 224, 224]), Negative: torch.Size([64, 3, 224, 224])


In [8]:
# Report the number of batches per epoch: training first, then validation.
loader_lengths = [len(loader) for loader in (train_loader, val_loader)]
print(*loader_lengths)

787 66


In [9]:
import os

# Launch cell: authenticate with Weights & Biases, start a run, train the
# model with SpamDetectorTrainer, and always close the run afterwards.
#
# SECURITY: never embed the W&B API key in the notebook. Provide it through the
# WANDB_API_KEY environment variable instead. When the variable is unset,
# `key=None` lets wandb fall back to its own credential lookup.
# NOTE: the key that was previously hard-coded here has been exposed and
# should be revoked and rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(
    project="review thesis project",
    name="true_18_NTXentLoss_hardnegative_swinv2",
    # To continue an earlier run, re-enable the two options below.
    # id="r2k837mu",
    # resume="allow"
)

# Release cached GPU memory left over from earlier cells before training.
torch.cuda.empty_cache()
trainer = SpamDetectorTrainer(
    model=model,
    criterion=criterion,
    train_loader=train_loader,
    valid_loader=val_loader,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    device=device,
    epochs=num_epochs,
    log_writer=wandb,
    patience=5,
)

# Close the W&B run even if training raises or is interrupted, so the run is
# not left open on the server.
try:
    trainer.train()
finally:
    wandb.finish()




[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Currently logged in as: [33mlong02042003[0m ([33mmosque[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Tracking run with wandb version 0.19.6


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250428_090847-ahj4ksll[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mtrue_18_NTXentLoss_hardnegative_swinv2[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/mosque/review%20thesis%20project[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/mosque/review%20thesis%20project/runs/ahj4ksll[0m


Start training!
















Epoch [1/25], Train Loss: 3.6137, Val Loss: 2.9972
New best validation loss: 2.9972. Saving checkpoint.
















Epoch [2/25], Train Loss: 3.2086, Val Loss: 2.7372
New best validation loss: 2.7372. Saving checkpoint.
















Epoch [3/25], Train Loss: 3.0115, Val Loss: 2.6426
New best validation loss: 2.6426. Saving checkpoint.
















Epoch [4/25], Train Loss: 2.9068, Val Loss: 2.5838
New best validation loss: 2.5838. Saving checkpoint.
















Epoch [5/25], Train Loss: 2.8318, Val Loss: 2.5473
New best validation loss: 2.5473. Saving checkpoint.
















Epoch [6/25], Train Loss: 2.7794, Val Loss: 2.5140
New best validation loss: 2.5140. Saving checkpoint.
















Epoch [7/25], Train Loss: 2.7303, Val Loss: 2.4917
New best validation loss: 2.4917. Saving checkpoint.














Epoch [8/25], Train Loss: 2.6882, Val Loss: 2.4760
New best validation loss: 2.4760. Saving checkpoint.
















Epoch [9/25], Train Loss: 2.6526, Val Loss: 2.4590
New best validation loss: 2.4590. Saving checkpoint.
















Epoch [10/25], Train Loss: 2.6187, Val Loss: 2.4490
New best validation loss: 2.4490. Saving checkpoint.
















Epoch [11/25], Train Loss: 2.5921, Val Loss: 2.4378
New best validation loss: 2.4378. Saving checkpoint.














Epoch [12/25], Train Loss: 2.5669, Val Loss: 2.4360
New best validation loss: 2.4360. Saving checkpoint.
















Epoch [13/25], Train Loss: 2.5411, Val Loss: 2.4297
New best validation loss: 2.4297. Saving checkpoint.
















Epoch [14/25], Train Loss: 2.5210, Val Loss: 2.4261
New best validation loss: 2.4261. Saving checkpoint.














Epoch [15/25], Train Loss: 2.5026, Val Loss: 2.4212
New best validation loss: 2.4212. Saving checkpoint.
















Epoch [16/25], Train Loss: 2.4873, Val Loss: 2.4165
New best validation loss: 2.4165. Saving checkpoint.
















Epoch [17/25], Train Loss: 2.4756, Val Loss: 2.4156
New best validation loss: 2.4156. Saving checkpoint.
















Epoch [18/25], Train Loss: 2.4628, Val Loss: 2.4115
New best validation loss: 2.4115. Saving checkpoint.
















Epoch [19/25], Train Loss: 2.4534, Val Loss: 2.4106
New best validation loss: 2.4106. Saving checkpoint.














Epoch [20/25], Train Loss: 2.4481, Val Loss: 2.4080
New best validation loss: 2.4080. Saving checkpoint.














