In [None]:
# Kernel smoke test left over from setup (this cell was never executed); safe to remove.
print("Hello World!", end="\n")

In [1]:
import csv
import numpy as np
import random
import torch
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
# Use the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [8]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt

In [7]:
class Discriminator(nn.Module):
    """Small CNN binary image classifier (presumably real vs. fake -- TODO confirm label meaning).

    The layer sizes assume 3x32x32 inputs (e.g. CIFAR-10 -- TODO confirm
    against the dataset): ``10 * 5 * 5`` in the first Linear layer is the
    flattened size after two conv(5) + pool(2) stages on a 32x32 image
    (32 -> 28 -> 14 -> 10 -> 5).

    ``forward`` returns per-image probabilities of shape ``(N, 1)`` in
    ``[0, 1]`` (final Sigmoid), so callers should threshold at 0.5 and train
    with ``nn.BCELoss``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(3, 5, 5),    # (N, 3, 32, 32) -> (N, 5, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # -> (N, 5, 14, 14)
            nn.Conv2d(5, 10, 5),   # -> (N, 10, 10, 10)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),    # -> (N, 10, 5, 5); required for the 10*5*5 below
            # nn.Linear only acts on the last dimension, so the feature maps
            # must be flattened to (N, 250) before the classifier head.
            nn.Flatten(),
            nn.Linear(10 * 5 * 5, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a batch of images ``x`` to probabilities of shape ``(N, 1)``."""
        return self.CNN(x)

In [10]:
def train_discriminator(model, train_data, val_data, num_epochs=5, learning_rate=0.001, batch_size=64):
    """Train a binary classifier with Adam and plot per-epoch train/validation accuracy.

    Args:
        model: module mapping an input batch to per-sample probabilities in
            [0, 1] (e.g. ``Discriminator``, which ends in a Sigmoid).
        train_data: dataset yielding ``(input, label)`` pairs used for training.
        val_data: dataset yielding ``(input, label)`` pairs used for validation.
        num_epochs: number of full passes over ``train_data``.
        learning_rate: Adam learning rate.
        batch_size: mini-batch size for both loaders.

    Side effects: seeds torch's global RNG, updates ``model`` in place, prints
    the validation accuracy each epoch and shows a matplotlib figure.
    """
    torch.manual_seed(1)
    # The model already outputs probabilities (final Sigmoid), so BCELoss is the
    # matching criterion; BCEWithLogitsLoss would apply a second sigmoid.
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Order does not affect accuracy, so the validation set is not shuffled.
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # Send batches to wherever the model's parameters live instead of hard-coding .cuda().
    device = next(model.parameters()).device

    training_accuracies = np.zeros(num_epochs)
    validation_accuracies = np.zeros(num_epochs)

    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            # BCELoss needs float targets with the same shape as the predictions.
            labels = labels.to(device).float().reshape(-1)
            probs = model(inputs).reshape(-1)
            loss = criterion(probs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        training_accuracies[epoch] = get_accuracy_discriminator(model, train_loader, batch_size)
        val_accuracy = get_accuracy_discriminator(model, val_loader, batch_size)
        validation_accuracies[epoch] = val_accuracy
        print("Iteration", epoch, "Accuracy = ", val_accuracy)

    epochs = range(1, num_epochs + 1)
    plt.plot(epochs, training_accuracies, label="Train")
    plt.plot(epochs, validation_accuracies, label="Validation")
    plt.title("Discriminator accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.show()

def get_accuracy_discriminator(model, loader, batch_size):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    ``model`` is expected to output one probability per sample (as
    ``Discriminator`` does via its final Sigmoid). An output >= 0.5 is read as
    a prediction of label 1, otherwise label 0. A sample is correct only when
    its label equals the predicted class; labels other than 0/1 never count.

    Args:
        model: module mapping an input batch to per-sample probabilities.
        loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        batch_size: kept for backward compatibility; the real size of each
            batch is used, so a smaller final batch is handled correctly.

    Returns:
        Accuracy in [0, 1] as a float, or 0.0 if ``loader`` yields no samples.
    """
    # Evaluate on the model's device rather than hard-coding .cuda().
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # gradients are not needed for evaluation
            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device).reshape(-1)
                predicted_pos = model(inputs).reshape(-1) >= 0.5
                hits = (predicted_pos & (labels == 1)) | (~predicted_pos & (labels == 0))
                correct += int(hits.sum().item())
                total += labels.numel()
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    return correct / total if total else 0.0