Group: 

Members: 

In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm 

# Pick the compute device once at import time: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# ===============================
# Self-Defined Dataloader
# ===============================
class data_loader(Dataset):
    """Binary real-vs-fake image dataset backed by a two-folder directory.

    ``data_dir`` must contain the sub-directories ``0_real`` (label 0) and
    ``1_fake`` (label 1). Real samples occupy the first indices, fake samples
    follow. Each item is returned as ``(transformed_image, label)``.
    """

    def __init__(self, data_dir, transform=None):
        """Index the image files under ``data_dir``.

        Args:
            data_dir: Directory holding the ``0_real`` and ``1_fake`` folders.
            transform: Optional callable applied to each RGB PIL image. When
                omitted, the original pipeline is used: a random resized crop
                to 64x64 followed by conversion to a tensor.
        """
        real = os.path.join(data_dir, '0_real')
        fake = os.path.join(data_dir, '1_fake')

        self.full_filenames_real = self._list_files(real)
        self.full_filenames_fake = self._list_files(fake)
        self.full_filenames = self.full_filenames_real + self.full_filenames_fake

        self.labels_real = [0] * len(self.full_filenames_real)
        self.labels_fake = [1] * len(self.full_filenames_fake)
        self.labels = self.labels_real + self.labels_fake

        if transform is None:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(size=(64, 64)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the full paths of the regular files in ``directory``, sorted by name.

        os.listdir returns entries in arbitrary order, so sorting keeps the
        index-to-sample mapping stable between runs. Sub-directories are
        skipped because they cannot be opened as images.
        """
        candidates = (
            os.path.join(directory, name) for name in sorted(os.listdir(directory))
        )
        return [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the total number of indexed images (real + fake)."""
        return len(self.full_filenames)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform and return ``(image, label)``."""
        # The context manager releases the underlying file as soon as the
        # pixel data has been copied into the converted image.
        with Image.open(self.full_filenames[idx]) as handle:
            image = handle.convert("RGB")
        image = self.transform(image)
        label = self.labels[idx]
        return image, label


In [None]:
# ===============================
# Neural Network
# ===============================
class CNN(nn.Module):
    """Real-vs-synthetic image classifier fusing ViT-Base and Swin-Base features.

    Both backbones are created headless (``num_classes=0``); their feature
    vectors are concatenated and passed through a small MLP that outputs two
    logits.
    """

    def __init__(self, pretrained=True, freeze_backbone=True, dropout=0.3, input_size=224):
        """Build the two backbones and the fusion head.

        Args:
            pretrained: Forwarded to ``create_model`` for both backbones.
            freeze_backbone: When true, backbone parameters get
                ``requires_grad = False`` so only the fusion MLP is trained.
            dropout: Dropout probability used inside the fusion MLP.
            input_size: Side length images are resized to before they reach
                the backbones. Pass ``None`` to disable resizing.
        """
        # Zero-argument super(): the previous call named a class (`Model`)
        # that does not exist in this file and raised NameError.
        super().__init__()

        # `create_model` was used without ever being imported; it is timm's
        # model factory, imported here so the class is self-contained.
        from timm import create_model

        self.input_size = input_size

        # === ViT-Base Patch16 224 ===
        self.vit = create_model('vit_base_patch16_224', pretrained=pretrained, num_classes=0)  # 768-dim
        # === Swin-Base Patch4 Window7 224 ===
        self.swin = create_model('swin_base_patch4_window7_224', pretrained=pretrained, num_classes=0)  # 1024-dim

        # Freeze backbones (recommended for AIGC detection with limited data)
        if freeze_backbone:
            for backbone in (self.vit, self.swin):
                for param in backbone.parameters():
                    param.requires_grad = False

        # Fusion MLP: 768 + 1024 = 1792 → 512 → 128 → 2
        self.fusion = nn.Sequential(
            nn.Linear(768 + 1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(128, 2)  # Exactly 2 classes: real vs synthetic
        )

    def forward(self, x):
        """Return class logits of shape ``[B, 2]`` for an image batch ``[B, 3, H, W]``."""
        # Both checkpoints are the `_224` variants, while the dataset in this
        # file produces 64x64 crops; bring the batch to the backbone resolution.
        if self.input_size is not None:
            target = (self.input_size, self.input_size)
            if tuple(x.shape[-2:]) != target:
                x = F.interpolate(x, size=target, mode="bilinear", align_corners=False)

        # Extract features
        vit_feat = self.vit(x)           # [B, 768]
        swin_feat = self.swin(x)         # [B, 1024]

        # Fallback when Swin hands back un-pooled features: average every axis
        # between batch and channel ([B, N, C], or [B, H, W, C] — assumes
        # channels-last; confirm against the installed timm version).
        if swin_feat.dim() > 2:
            swin_feat = swin_feat.flatten(1, -2).mean(1)  # [B, 1024]

        # Concatenate
        combined = torch.cat([vit_feat, swin_feat], dim=1)  # [B, 1792]

        # Final classification
        return self.fusion(combined)

In [None]:
# ===============================
# Train-Validate
# ===============================
def main():
    """Train the classifier, report validation accuracy per epoch and save the weights.

    Reads images from ``data/train`` and ``data/val``, trains for a fixed
    number of epochs with Adam and cross-entropy loss, prints one summary line
    per epoch and finally writes the state dict to ``model.pth``.
    """
    data_root = "data"
    batch_size = 32
    epochs = 10
    lr = 1e-4

    train_dataset = data_loader(os.path.join(data_root, "train"))
    val_dataset = data_loader(os.path.join(data_root, "val"))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    # Frozen parameters never receive gradients, so only the trainable ones
    # are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    for epoch in range(epochs):
        # ---- Train ----
        model.train()
        total_loss, total_correct, total = 0, 0, 0
        loop = tqdm(train_loader, total=len(train_loader), desc=f"Epoch [{epoch+1}/{epochs}]")

        for imgs, labels in loop:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight by batch size so the
            # epoch average stays correct when the last batch is smaller.
            total_loss += loss.item() * imgs.size(0)
            total_correct += (outputs.argmax(1) == labels).sum().item()
            total += imgs.size(0)

        # max(..., 1) avoids ZeroDivisionError when a split yields no samples.
        train_loss = total_loss / max(total, 1)
        train_acc = total_correct / max(total, 1)

        # ---- Validate ----
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                # A stray `****` after this statement previously made the
                # whole file a syntax error.
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                preds = outputs.argmax(1)
                val_correct += (preds == labels).sum().item()
                val_total += imgs.size(0)
        val_acc = val_correct / max(val_total, 1)

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    # Save Model
    torch.save(model.state_dict(), "model.pth")
    print("Model saved")



In [None]:
# Start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
