In [None]:
import torch
import torchvision

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
# Data pipeline configuration, kept in one place so it is easy to tune.
DATA_ROOT = './data'
IMAGE_SIZE = (128, 128)   # Net's first Linear layer (64 * 28 * 28) assumes this input size
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000

# Transform: resize and normalize
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),       # Resize to a smaller, fixed size
    transforms.ToTensor(),               # Convert image to PyTorch tensor
    # Per-channel statistics spelled out for the three RGB channels
    # instead of relying on a single value being broadcast: normalize to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build each split once, then wrap it in a loader.
train_dataset = datasets.CelebA(root=DATA_ROOT, split='train', download=True,
                                transform=transform)
test_dataset = datasets.CelebA(root=DATA_ROOT, split='test', download=True,
                               transform=transform)

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
# Accuracy does not depend on sample order, so the test split is not shuffled;
# this keeps evaluation batches deterministic from run to run.
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)

# Check one batch
images, labels = next(iter(train_loader))
print("Image batch shape:", images.shape)
print("Label shape:", labels.shape)

Downloading...
From (original): https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM
From (redirected): https://drive.usercontent.google.com/download?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM&confirm=t&uuid=d876c8dd-0c34-4b3d-bd63-7a9bc4e3cbdb
To: /content/data/celeba/img_align_celeba.zip
100%|██████████| 1.44G/1.44G [00:18<00:00, 76.6MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pblRyaVFSWGxPY0U
To: /content/data/celeba/list_attr_celeba.txt
100%|██████████| 26.7M/26.7M [00:00<00:00, 73.2MB/s]
Downloading...
From: https://drive.google.com/uc?id=1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS
To: /content/data/celeba/identity_CelebA.txt
100%|██████████| 3.42M/3.42M [00:00<00:00, 26.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pbThiMVRxWXZ4dU0
To: /content/data/celeba/list_bbox_celeba.txt
100%|██████████| 6.08M/6.08M [00:00<00:00, 86.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=0B7EVK8r0v71pd0FJY3Blby1HUTQ
To: /content/data/celeba/list_landm

Image batch shape: torch.Size([64, 3, 128, 128])
Label shape: torch.Size([64, 40])


In [None]:
# Show the attribute names exposed by the training dataset. In the captured
# output the list ends with an empty string, so it has 41 entries while the
# label batch printed earlier has 40 columns.
train_loader.dataset.attr_names

['5_o_Clock_Shadow',
 'Arched_Eyebrows',
 'Attractive',
 'Bags_Under_Eyes',
 'Bald',
 'Bangs',
 'Big_Lips',
 'Big_Nose',
 'Black_Hair',
 'Blond_Hair',
 'Blurry',
 'Brown_Hair',
 'Bushy_Eyebrows',
 'Chubby',
 'Double_Chin',
 'Eyeglasses',
 'Goatee',
 'Gray_Hair',
 'Heavy_Makeup',
 'High_Cheekbones',
 'Male',
 'Mouth_Slightly_Open',
 'Mustache',
 'Narrow_Eyes',
 'No_Beard',
 'Oval_Face',
 'Pale_Skin',
 'Pointy_Nose',
 'Receding_Hairline',
 'Rosy_Cheeks',
 'Sideburns',
 'Smiling',
 'Straight_Hair',
 'Wavy_Hair',
 'Wearing_Earrings',
 'Wearing_Hat',
 'Wearing_Lipstick',
 'Wearing_Necklace',
 'Wearing_Necktie',
 'Young',
 '']

In [None]:
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

In [None]:
class Net(nn.Module):
    """Small CNN that maps a batch of RGB images to one raw logit per attribute.

    Three unpadded 3x3 convolution stages (the first two followed by 2x2 max
    pooling) feed a three-layer fully connected head. The head's input width
    is fixed at 64 * 28 * 28, which is what the convolution stack produces for
    128x128 inputs (128 -> 126 -> 63 -> 61 -> 30 -> 28 per spatial side).
    Other input sizes will not match the first Linear layer.

    Args:
        num_attributes: Number of output logits (one per predicted
            attribute). Defaults to 40.
    """

    def __init__(self, num_attributes=40):
        super().__init__()
        # Submodule names and the order of layers inside each Sequential are
        # kept stable so that state-dict keys do not change.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Linear(64 * 28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_attributes),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, num_attributes) for image batch ``x``.

        No sigmoid is applied here; callers apply it (or use a loss that
        works on logits).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # flatten() keeps the batch dimension and, unlike view(), also works
        # when the activations are not contiguous in memory.
        x = x.flatten(start_dim=1)
        return self.fc(x)


In [None]:
# Optimizer hyperparameters (SGD with momentum).
LEARNING_RATE = 0.01
MOMENTUM = 0.9

model = Net()
# The network returns raw logits, so use the loss that takes logits directly.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)

In [None]:
def train(epoch):
    """Run one optimization pass over ``train_loader``, updating ``model`` in place.

    Relies on the notebook-level ``model``, ``train_loader``, ``optimizer``
    and ``criterion``.

    Args:
        epoch: Index of the current epoch; used to label the progress bar.

    Returns:
        The mean per-batch loss over the epoch as a float, or ``nan`` if the
        loader yielded no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch}")
    for inputs, targets in progress:
        optimizer.zero_grad()
        outputs = model(inputs)

        # Cast the labels to float32 before handing them to the loss.
        targets = targets.to(torch.float32)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Track the running mean so training progress is visible.
        running_loss += loss.item()
        num_batches += 1
        progress.set_postfix(loss=running_loss / num_batches)

    return running_loss / num_batches if num_batches else float("nan")


In [None]:
# Number of full passes over the training set.
NUM_EPOCHS = 3

for epoch in range(NUM_EPOCHS):
    train(epoch)

In [None]:
def test(num_attributes=40):
    """Evaluate ``model`` on ``test_loader`` and print per-attribute accuracy.

    Relies on the notebook-level ``model`` and ``test_loader``. A prediction
    counts as positive when the sigmoid of its logit exceeds 0.5. Prints the
    accuracy averaged over all attributes, followed by one line per attribute
    (numbered from 1).

    Args:
        num_attributes: Number of label columns per sample; must match the
            width of the model output and of the targets. Defaults to 40.
    """
    model.eval()
    correct_per_attribute = torch.zeros(num_attributes)
    total_samples = 0

    with torch.no_grad():  # inference only, no gradient tracking needed
        for inputs, targets in test_loader:

            logits = model(inputs)  # shape: (batch_size, num_attributes)
            probs = torch.sigmoid(logits)  # apply sigmoid for multi-label
            preds = (probs > 0.5).float()  # thresholding

            # Column-wise count of predictions that agree with the labels.
            correct_per_attribute += (preds == targets).sum(dim=0)
            total_samples += targets.size(0)

    # Accuracy per attribute
    acc_per_attribute = correct_per_attribute / total_samples
    avg_acc = acc_per_attribute.mean().item()

    print(f"Average accuracy across {num_attributes} attributes: {avg_acc * 100:.2f}%")

    # Individual attribute accuracies
    for i, acc in enumerate(acc_per_attribute):
        print(f"Attribute {i+1}: Accuracy = {acc.item() * 100:.2f}%")


In [None]:
test()

Average accuracy across 40 attributes: 89.47%
Attribute 1: Accuracy = 92.36%
Attribute 2: Accuracy = 79.73%
Attribute 3: Accuracy = 80.54%
Attribute 4: Accuracy = 83.21%
Attribute 5: Accuracy = 98.01%
Attribute 6: Accuracy = 94.54%
Attribute 7: Accuracy = 70.13%
Attribute 8: Accuracy = 82.29%
Attribute 9: Accuracy = 86.64%
Attribute 10: Accuracy = 95.02%
Attribute 11: Accuracy = 95.37%
Attribute 12: Accuracy = 86.69%
Attribute 13: Accuracy = 91.69%
Attribute 14: Accuracy = 94.89%
Attribute 15: Accuracy = 95.78%
Attribute 16: Accuracy = 98.92%
Attribute 17: Accuracy = 96.49%
Attribute 18: Accuracy = 97.67%
Attribute 19: Accuracy = 89.86%
Attribute 20: Accuracy = 85.60%
Attribute 21: Accuracy = 96.48%
Attribute 22: Accuracy = 92.11%
Attribute 23: Accuracy = 96.27%
Attribute 24: Accuracy = 86.36%
Attribute 25: Accuracy = 94.07%
Attribute 26: Accuracy = 73.69%
Attribute 27: Accuracy = 96.68%
Attribute 28: Accuracy = 75.09%
Attribute 29: Accuracy = 92.51%
Attribute 30: Accuracy = 94.57%
Att