In [2]:

!pip install torchvision transformers datasets scikit-learn --quiet

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
import os

In [3]:
# --- Self-Attention Layer ---
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key   = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, W, H = x.size()
        proj_query = self.query(x).view(B, -1, W*H).permute(0, 2, 1)
        proj_key   = self.key(x).view(B, -1, W*H)
        energy     = torch.bmm(proj_query, proj_key)
        attention  = torch.softmax(energy, dim=-1)
        proj_value = self.value(x).view(B, -1, W*H)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, W, H)
        return self.gamma * out + x

In [4]:
# --- Generator ---
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super().__init__()
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, 4, 1, 0),
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            SelfAttention(64),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

In [5]:
# --- Discriminator ---
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            SelfAttention(64),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

In [6]:
# --- Load Dataset ---
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])
dataset = torchvision.datasets.Flowers102(root="./data", split='train', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)


In [7]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = Generator().to(device)
disc = Discriminator().to(device)
opt_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_disc = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()


epochs = 10
z_dim = 100
fixed_noise = torch.randn(16, z_dim, 1, 1).to(device)
os.makedirs("generated_images", exist_ok=True)

for epoch in range(epochs):
    for batch in loader:
        imgs, _ = batch
        imgs = imgs.to(device)
        batch_size = imgs.size(0)

        # Labels
        real = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)

        # --- Train Discriminator ---
        noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
        fake_imgs = gen(noise)
        disc_real = disc(imgs)
        disc_fake = disc(fake_imgs.detach())
        loss_disc = (criterion(disc_real, real) + criterion(disc_fake, fake)) / 2

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # --- Train Generator ---
        output = disc(fake_imgs)
        loss_gen = criterion(output, real)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

    # --- Accuracy Calculation ---
    with torch.no_grad():
        pred_real = (disc_real > 0.5).float()
        pred_fake = (disc_fake < 0.5).float()
        correct = torch.sum(pred_real) + torch.sum(pred_fake)
        total = 2 * batch_size
        disc_acc = (correct / total).item() * 100

    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}, Disc Acc: {disc_acc:.2f}%")
    save_image(gen(fixed_noise), f"generated_images/epoch_{epoch+1}.png", normalize=True)


torch.save(gen.state_dict(), "self_attention_generator.pth")
torch.save(disc.state_dict(), "self_attention_discriminator.pth")

Epoch [1/10] Loss D: 0.5536, Loss G: 0.7974, Disc Acc: 99.17%
Epoch [2/10] Loss D: 0.4369, Loss G: 1.2314, Disc Acc: 95.00%
Epoch [3/10] Loss D: 0.4943, Loss G: 0.9314, Disc Acc: 98.33%
Epoch [4/10] Loss D: 0.3436, Loss G: 1.5118, Disc Acc: 96.67%
Epoch [5/10] Loss D: 0.5808, Loss G: 0.9315, Disc Acc: 51.67%
Epoch [6/10] Loss D: 0.6192, Loss G: 0.9935, Disc Acc: 85.83%
Epoch [7/10] Loss D: 0.6688, Loss G: 1.0933, Disc Acc: 75.00%
Epoch [8/10] Loss D: 0.4398, Loss G: 1.3019, Disc Acc: 94.17%
Epoch [9/10] Loss D: 0.3468, Loss G: 1.4250, Disc Acc: 95.00%
Epoch [10/10] Loss D: 0.3338, Loss G: 1.5490, Disc Acc: 90.00%


In [8]:
# --- Evaluation ---
def evaluate_model(generator, discriminator, num_samples=500):
    generator.eval()
    discriminator.eval()
    real_labels = []
    pred_labels = []

    for i, (real_imgs, _) in enumerate(loader):
        if i * 64 > num_samples:
            break
        real_imgs = real_imgs.to(device)
        fake_noise = torch.randn(real_imgs.size(0), z_dim, 1, 1).to(device)
        fake_imgs = generator(fake_noise)

        real_preds = discriminator(real_imgs).detach().cpu().numpy().round()
        fake_preds = discriminator(fake_imgs).detach().cpu().numpy().round()

        real_labels += [1] * len(real_preds)
        pred_labels += list(real_preds.flatten())

        real_labels += [0] * len(fake_preds)
        pred_labels += list(fake_preds.flatten())

    cm = confusion_matrix(real_labels, pred_labels)
    precision = precision_score(real_labels, pred_labels)
    recall = recall_score(real_labels, pred_labels)
    accuracy = accuracy_score(real_labels, pred_labels)
    return cm, precision, recall, accuracy


conf_matrix, prec, rec, acc = evaluate_model(gen, disc)
print("\nFinal Evaluation Metrics:")
print("Confusion Matrix:\n", conf_matrix)
print("Precision: {:.4f}".format(prec))
print("Recall: {:.4f}".format(rec))
print("Accuracy: {:.4f}".format(acc))


Final Evaluation Metrics:
Confusion Matrix:
 [[512   0]
 [ 93 419]]
Precision: 1.0000
Recall: 0.8184
Accuracy: 0.9092
