In [1]:
import torch

In [2]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
device

device(type='cuda')

In [4]:
import os
import glob
import numpy as np
from PIL import Image 
import torch
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transforms

In [5]:
import torch
from torchvision import transforms
from PIL import Image

class celeba(Dataset):
    def __init__(self, data_path=None, label_path=None, Z_vals=None):
        self.data_path = data_path
        self.label_path = label_path
        self.Z_vals = Z_vals

        self.transform = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, idx):
        # images
        image_set = Image.open(self.data_path[idx])
        image_tensor = self.transform(image_set)
        
        # bald labels
        image_label = torch.Tensor(self.label_path[idx])
        
        # male labels
        z = torch.Tensor(self.Z_vals[idx])
        
        # Move to GPU
        image_tensor = image_tensor.to('cuda')
        image_label = image_label.to('cuda')
        Z = z.to('cuda')
        
        return image_tensor, image_label, Z



In [6]:
def main():
    
    
    file_name = "/kaggle/input/totaldata/Total/Total/"

    data_path = sorted(glob.glob(file_name + '*.jpg'))
    #print(len(data_path))

    label_path = "/kaggle/input/totaldata/target_values.txt"
    label_list = open(label_path).readlines()
    data_label = []
    for i in range(len(label_list)):
        data_label.append(label_list[i].split( ))
    data_label = [[int(x[0])] for x in data_label]
    #print(data_label)
    
    Z_path = "/kaggle/input/totaldata/z_vals.txt"
    Z_list = open(Z_path).readlines()
    Z_label = []
    for z in range(len(Z_list)):
        Z_label.append(Z_list[z].split( ))
    Z_label = [[int(y[0])] for y in Z_label]
    #print(Z_label)
    
    dataset = celeba(data_path, data_label, Z_label)

    trainloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    
    return trainloader

In [7]:
dataloader = main()

In [8]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x798fe6c5fa90>

In [9]:
len(dataloader)*64

5696

In [10]:
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 158MB/s]


In [11]:
for params in model.parameters():
    params.requires_grad = False

In [12]:
in_features = model.fc.in_features
out_features = 1

class Classifier(nn.Module):
    def __init__(self, in_features, out_features):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x

model.fc = Classifier(in_features=in_features, out_features=out_features)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [13]:
for params in model.fc.parameters():
    params.requires_grad = True

In [14]:
model.to('cuda')

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [15]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [16]:
count_parameters(model)

513

In [17]:
class Adversary(nn.Module):
    def __init__(self):
        super(Adversary, self).__init__()
        self.fc = nn.Linear(1, 64)
        self.fc2 = nn.Linear(64, 1)
        
    def forward(self, x):
        x = self.fc(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x


In [18]:
adversary = Adversary()

In [19]:
predictor_criterion = nn.BCELoss()
adversary_criterion = nn.BCELoss()

In [20]:
predictor_optimizer = optim.Adam(model.parameters(), lr=1e-3)
adversary_optimizer = optim.Adam(adversary.parameters(), lr=1e-3)

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
adversary.to(device)
alpha = 0.1

for epoch in range(20):
    predictor_running_loss = 0
    adversary_running_loss = 0

    model.train()
    adversary.train()

    for data in dataloader:

        images, labels, z = data

        images = images.to(device)
        labels = labels.to(device) 
        z = z.to(device)  

        # Reset gradients
        #predictor_optimizer.zero_grad()
        #adversary_optimizer.zero_grad()

        # Forward pass
        predictor_outputs = model(images)
        adversary_outputs = adversary(predictor_outputs)

        # Compute losses
        predictor_loss = predictor_criterion(predictor_outputs, labels)
        adversary_loss = adversary_criterion(adversary_outputs, z)

        combined_loss = (predictor_loss) + (predictor_loss/adversary_loss) - (alpha * adversary_loss) 

        adversary_optimizer.zero_grad()
        adversary_loss.backward(retain_graph=True)

        predictor_optimizer.zero_grad()
        combined_loss.backward()

        adversary_optimizer.step()
        predictor_optimizer.step()

        predictor_running_loss += predictor_loss.item()
        adversary_running_loss += adversary_loss.item()

    pred_loss = predictor_running_loss/(len(dataloader)*64)
    adv_loss = adversary_running_loss/(len(dataloader)*64)

    print(f'Epoch {epoch+1}, Predictor Loss: {pred_loss:.4f}, Adversary Loss: {adv_loss:.4f}')


Epoch 1, Predictor Loss: 0.0077, Adversary Loss: 0.0116
Epoch 2, Predictor Loss: 0.0053, Adversary Loss: 0.0096
Epoch 3, Predictor Loss: 0.0047, Adversary Loss: 0.0091
Epoch 4, Predictor Loss: 0.0044, Adversary Loss: 0.0088
Epoch 5, Predictor Loss: 0.0042, Adversary Loss: 0.0086
Epoch 6, Predictor Loss: 0.0041, Adversary Loss: 0.0084
Epoch 7, Predictor Loss: 0.0040, Adversary Loss: 0.0083
Epoch 8, Predictor Loss: 0.0039, Adversary Loss: 0.0083
Epoch 9, Predictor Loss: 0.0038, Adversary Loss: 0.0081
Epoch 10, Predictor Loss: 0.0038, Adversary Loss: 0.0082
Epoch 11, Predictor Loss: 0.0037, Adversary Loss: 0.0081
Epoch 12, Predictor Loss: 0.0036, Adversary Loss: 0.0079
Epoch 13, Predictor Loss: 0.0037, Adversary Loss: 0.0079
Epoch 14, Predictor Loss: 0.0036, Adversary Loss: 0.0081
Epoch 15, Predictor Loss: 0.0037, Adversary Loss: 0.0080
Epoch 16, Predictor Loss: 0.0036, Adversary Loss: 0.0079
Epoch 17, Predictor Loss: 0.0036, Adversary Loss: 0.0079
Epoch 18, Predictor Loss: 0.0036, Advers

In [22]:
checkpoint_path = '/kaggle/working/adversarial_classifier_checkpoint.pth'

In [24]:
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': predictor_optimizer.state_dict(),
    'loss': pred_loss,
}, checkpoint_path)
