Group: 

Members: 

In [10]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from timm import create_model
import shutil
from pathlib import Path
import random
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import numpy as np
import csv, json

device = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
!rm -rf data/
!rm -rf test/
!rm -rf __MACOSX/
!unzip data.zip

Archive:  data.zip
   creating: data
  inflating: __MACOSX/._data         
  inflating: data/.DS_Store          
  inflating: __MACOSX/data/._.DS_Store  
   creating: data/train
   creating: data/train/1_fake
  inflating: __MACOSX/data/train/._1_fake  
   creating: data/train/0_real
  inflating: __MACOSX/data/train/._0_real  
  inflating: data/train/1_fake/000000512793.jpg  
  inflating: __MACOSX/data/train/1_fake/._000000512793.jpg  
  inflating: data/train/1_fake/000000298197.jpg  
  inflating: __MACOSX/data/train/1_fake/._000000298197.jpg  
  inflating: data/train/1_fake/000000262985.jpg  
  inflating: __MACOSX/data/train/1_fake/._000000262985.jpg  
  inflating: data/train/1_fake/000000363331.jpg  
  inflating: __MACOSX/data/train/1_fake/._000000363331.jpg  
  inflating: data/train/1_fake/000000057992.jpg  
  inflating: __MACOSX/data/train/1_fake/._000000057992.jpg  
  inflating: data/train/1_fake/000000520047.jpg  
  inflating: __MACOSX/data/train/1_fake/._000000520047.jpg  
  infl

In [12]:
import os
import shutil
from pathlib import Path
import random
import argparse

def split_train_to_val(source_dir, val_ratio=0.2, seed=42):
    source_dir = Path(source_dir)
    val_dir = source_dir.parent / "val"
    
    random.seed(seed)
    
    # Define class subdirectories
    classes = ["0_real", "1_fake"]
    
    for class_name in classes:
        train_class_dir = source_dir / class_name
        val_class_dir = val_dir / class_name
        
        if not train_class_dir.exists():
            print(f"Warning: {train_class_dir} does not exist. Skipping.")
            continue
            
        # Create validation directory
        val_class_dir.mkdir(parents=True, exist_ok=True)
        
        # Get all image files
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
        files = [f for f in train_class_dir.iterdir() 
                if f.suffix.lower() in image_extensions and f.is_file()]
        
        if len(files) == 0:
            print(f"No images found in {train_class_dir}")
            continue
            
        # Calculate number to move
        num_to_move = max(1, int(len(files) * val_ratio))  # at least 1 image
        print(f"Moving {num_to_move}/{len(files)} images from {class_name} to validation")
        
        # Randomly select files
        files_to_move = random.sample(files, num_to_move)
        
        # Move them
        for file_path in files_to_move:
            dest_path = val_class_dir / file_path.name
            shutil.move(str(file_path), str(dest_path))
            # print(f"Moved: {file_path.name} → {dest_path}")

    print(f"\nDone! Validation set created at: {val_dir}")
    
split_train_to_val("data/train", val_ratio=0.15, seed=42)

Moving 3750/25000 images from 0_real to validation
Moving 3750/25000 images from 1_fake to validation

Done! Validation set created at: data/val


In [13]:
class data_loader(Dataset):
    def __init__(self, data_dir):

        real = os.path.join(data_dir, '0_real')
        fake = os.path.join(data_dir, '1_fake')

        file_names_real = os.listdir(real)
        file_names_fake = os.listdir(fake)

        self.full_filenames_real = [os.path.join(real, f) for f in file_names_real]
        self.full_filenames_fake = [os.path.join(fake, f) for f in file_names_fake]
        self.full_filenames = self.full_filenames_real + self.full_filenames_fake

        self.labels_real = [0 for _ in file_names_real]
        self.labels_fake = [1 for _ in file_names_fake]
        self.labels = self.labels_real + self.labels_fake

        self.transform_original = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])

        self.transform_aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(30),
            transforms.RandAugment(num_ops=2, magnitude=9),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.ToTensor(),
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.15), ratio=(0.3, 3.3), value='random'),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])

    def __len__(self):
        return len(self.full_filenames)

    def __getitem__(self, idx):
        image = Image.open(self.full_filenames[idx]).convert("RGB")
        image_aug = self.transform_aug(image)
        image_original = self.transform_original(image)
        label = self.labels[idx]
        return image_original, image_aug, label


In [14]:
# ===============================
# Neural NetWork
# ===============================
class CNN(nn.Module):
    def __init__(self, pretrained=True, freeze_backbone=True, dropout=0.3):
        super(CNN, self).__init__()
        # === Swin-B ===
        self.swin = create_model('swin_base_patch4_window7_224', pretrained=pretrained, num_classes=0)  # 1024-dim

        # Freeze backbones (recommended for AIGC detection with limited data)
        if freeze_backbone:
            for param in self.swin.parameters():
                param.requires_grad = False

        self.fusion = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(128, 2)  # Exactly 2 classes: real vs synthetic
        )

    def forward(self, x):
        # Extract features
        swin_feat = self.swin(x)         # [B, 1024]

        # For Swin, forward_features returns [B, H*W, C] → global avg pool if needed
        if len(swin_feat.shape) == 3:
            swin_feat = swin_feat.mean(1)  # [B, 1024]

        # Final classification
        out = self.fusion(swin_feat)
        return out

In [18]:
# ===============================
# Train-Validate
# ===============================
def main():
    data_root = "data"
    batch_size = 32 # Smaller batch = better generalization with dual-view
    epochs_list = [5, 10, 15]
    lr = 3e-5 # Lower LR works much better when unfreezing layers

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on {device}")

    # For each epochs, reset all settings
    for epochs in epochs_list:
        # ===============================
        # Metrics Logging Setup
        # ===============================
        metrics_log = {
            "train_loss": [], "train_acc": [], "train_prec": [], "train_rec": [], "train_f1": [],
            "val_acc": [], "val_prec": [], "val_rec": [], "val_f1": []
        }

        os.makedirs("logs", exist_ok=True)
        csv_path  = f"logs/training_metrics_{epochs}e.csv"
        json_path = f"logs/training_metrics_{epochs}e.json"
        # ===============================
        # Dataset & Dataloader
        # ===============================
        train_dataset = data_loader(os.path.join(data_root, "train"))
        val_dataset = data_loader(os.path.join(data_root, "val"))

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                num_workers=4 if torch.cuda.is_available() else 0, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False,
                                num_workers=4 if torch.cuda.is_available() else 0, pin_memory=True)

        # ===============================
        # Model + Optimizer + Loss
        # ===============================
        model = CNN().to(device)

        # Unfreeze last few blocks
        for p in model.swin.layers[-2:].parameters():
            p.requires_grad = True

        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr,
            weight_decay=1e-4
        )
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
        best_val_f1 = 0.0
        best_val_acc = 0.0

        for epoch in range(epochs):
            # ------------------- Training -------------------
            model.train()
            total_loss = 0.0
            train_all_preds = []
            train_all_labels = []

            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:02d}/{epochs} [Train]")

            for img_clean, img_aug, labels in pbar:
                img_clean = img_clean.to(device, non_blocking=True)
                img_aug = img_aug.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                optimizer.zero_grad()
                logits_clean = model(img_clean)
                logits_aug   = model(img_aug)

                # Proper dual-view ensembling
                logits = (logits_clean + logits_aug) / 2.0
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                preds = logits.argmax(dim=1)

                train_all_preds.extend(preds.cpu().numpy())
                train_all_labels.extend(labels.cpu().numpy())

                pbar.set_postfix({"Loss": f"{total_loss/(epoch+1):.4f}"})

            # Training metrics
            train_acc  = accuracy_score(train_all_labels, train_all_preds)
            train_prec, train_rec, train_f1, _ = precision_recall_fscore_support(
                train_all_labels, train_all_preds, average='macro', zero_division=0
            )
            avg_train_loss = total_loss / len(train_loader)

            # ------------------- Validation -------------------
            model.eval()
            val_all_preds = []
            val_all_labels = []

            with torch.no_grad():
                for img_clean, _, labels in val_loader:
                    img_clean = img_clean.to(device)
                    labels = labels.to(device)
                    logits = model(img_clean)
                    preds = logits.argmax(dim=1)
                    val_all_preds.extend(preds.cpu().numpy())
                    val_all_labels.extend(labels.cpu().numpy())

            val_acc = accuracy_score(val_all_labels, val_all_preds)
            val_prec, val_rec, val_f1, _ = precision_recall_fscore_support(
                val_all_labels, val_all_preds, average='macro', zero_division=0
            )

            scheduler.step()

            # Save best model (by macro F1)
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_val_acc = val_acc
                torch.save(model.state_dict(), "model_best.pth")
                print(f"New best model saved! Val F1: {val_f1:.4f} | Val Acc: {val_acc:.4f}")

            # ===============================
            # Log metrics for this epoch
            # ===============================
            metrics_log["train_loss"].append(avg_train_loss)
            metrics_log["train_acc"].append(train_acc)
            metrics_log["train_prec"].append(train_prec)
            metrics_log["train_rec"].append(train_rec)
            metrics_log["train_f1"].append(train_f1)
            metrics_log["val_acc"].append(val_acc)
            metrics_log["val_prec"].append(val_prec)
            metrics_log["val_rec"].append(val_rec)
            metrics_log["val_f1"].append(val_f1)

            # Print epoch summary
            print(f"\n=== Epoch {epoch+1:02d}/{epochs} ===")
            print(f"Train → Loss: {avg_train_loss:.4f} | Acc: {train_acc:.4f} | P: {train_prec:.4f} | R: {train_rec:.4f} | F1: {train_f1:.4f}")
            print(f"Val   → Acc: {val_acc:.4f} | P: {val_prec:.4f} | R: {val_rec:.4f} | F1: {val_f1:.4f}")
            print(f"Best Val → Acc: {best_val_acc:.4f} | F1: {best_val_f1:.4f}\n")

            # ===============================
            # Save Metrics to CSV & JSON
            # ===============================
            # CSV
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(["Epoch", "Train_Loss", "Train_Acc", "Train_Prec", "Train_Rec", "Train_F1",
                                "Val_Acc", "Val_Prec", "Val_Rec", "Val_F1"])
                for i in range(epochs):
                    writer.writerow([
                        i+1,
                        f"{metrics_log['train_loss'][i]:.4f}",
                        f"{metrics_log['train_acc'][i]:.4f}",
                        f"{metrics_log['train_prec'][i]:.4f}",
                        f"{metrics_log['train_rec'][i]:.4f}",
                        f"{metrics_log['train_f1'][i]:.4f}",
                        f"{metrics_log['val_acc'][i]:.4f}",
                        f"{metrics_log['val_prec'][i]:.4f}",
                        f"{metrics_log['val_rec'][i]:.4f}",
                        f"{metrics_log['val_f1'][i]:.4f}",
                    ])

            # JSON (extra backup + easy to load later)
            log_to_save = {k: [f"{v:.6f}" if isinstance(v, float) else v for v in vals] 
                        for k, vals in metrics_log.items()}
            log_to_save["best_val_accuracy"] = f"{best_val_acc:.6f}"
            log_to_save["best_val_f1"] = f"{best_val_f1:.6f}"
            log_to_save["total_epochs"] = epochs

            with open(json_path, 'w') as f:
                json.dump(log_to_save, f, indent=2)

            print(f"All metrics saved!")
            print(f"   → CSV : {csv_path}")
            print(f"   → JSON: {json_path}")
            print(f"Final Best Val Accuracy: {best_val_acc:.4f} | Best Val F1: {best_val_f1:.4f}\n{'='*60}\n")

        # Final save per epoch count
        torch.save(model.state_dict(), f"drive/MyDrive/CS4487/ViTB_Swin_{epochs}e.pth")
        print(f"Training with {epochs} epochs finished!")
        print(f"Final Best Val Accuracy: {best_val_acc:.4f} | Best Val F1: {best_val_f1:.4f}\n")

In [19]:
if __name__ == "__main__":
    main()

Training on cpu


Epoch 01/5 [Train]:   0%|          | 3/1328 [00:47<5:51:56, 15.94s/it, Loss=2.1338]


KeyboardInterrupt: 