<a href="https://colab.research.google.com/github/YuxuanLiu0622/ECE50024-Project-Team15/blob/main/checkpoint3_reweight_upload.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Checkpoint 3: Reimplementation of Learning to Reweight Examples for Robust Deep Learning on Toy Problems (FashionMNIST) Part 2: CNN with reweighting

**Team15: Hyun Soo Park, Andres Martinez, Heesoo Kim, Mingyu Kim, Yuxuan Liu**

In [1]:
!pip install tqdm
import time
from typing import List, Dict
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from copy import deepcopy
from tqdm import tqdm
import IPython



In [2]:
!pip install higher

Collecting higher
  Downloading higher-0.2.1-py3-none-any.whl (27 kB)
Installing collected packages: higher
Successfully installed higher-0.2.1


In [3]:
import higher

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def set_seed(seed: int = 0) -> None:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(0)

In [6]:
class CNN(nn.Module):

  def __init__(self):
    super(CNN, self).__init__()

    self.conv = nn.Conv2d(1, 8, kernel_size=3)
    self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
    self.conv3 = nn.Conv2d(16,32,kernel_size=3)
    self.conv4 = nn.Conv2d(32,32,kernel_size=3)
    self.fc = nn.Linear(512, 1)

  def forward(self, x):
    x = self.conv(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    x = F.relu(F.max_pool2d(x,2))
    x = self.conv3(x)
    x = F.relu(x)
    x = self.conv4(x)
    x = F.relu(x)
    x = F.relu(F.max_pool2d(x,2))
    x = x.view(x.size(0), -1)
    x = self.fc(x)

    return x

In [7]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                      torchvision.transforms.Normalize((0.1307,),(0.3081,))])

In [8]:
train_fmnist = torchvision.datasets.FashionMNIST(root="data", train=True, download=True,transform=transform)
test_fmnist = torchvision.datasets.FashionMNIST(root="data", train=False, download=True,transform=transform)
val_fmnist = torchvision.datasets.FashionMNIST(root="data", train=False, download=True,transform=transform)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 19456498.98it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 331158.20it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 6064310.83it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 4585321.09it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw






In [9]:
def split_dataset(class1: int, class2: int, imbalance_ratio: float, n_samples: int, train_dataset: torchvision.datasets) -> torch.utils.data.Dataset:
    new_data = train_dataset
    n_class1 = int(imbalance_ratio*n_samples)
    n_class2 = n_samples - n_class1
    class1_indices = (train_dataset.targets == class1).nonzero().squeeze()
    class2_indices = (train_dataset.targets == class2).nonzero().squeeze()

    # Randomly sample indices for each class based on the desired number of samples
    selected_class1_indices = class1_indices[torch.randperm(class1_indices.size(0))[:n_class1]]
    selected_class2_indices = class2_indices[torch.randperm(class2_indices.size(0))[:n_class2]]

    new_data.data = torch.cat((train_dataset.data[selected_class1_indices], train_dataset.data[selected_class2_indices]))

    # Update the targets based on the new number of samples
    new_data.targets = torch.cat((torch.zeros(selected_class1_indices.size()), torch.ones(selected_class2_indices.size())))
    return new_data

The training set is given an imbalanced proportion of 99.5%. The validation and testing set is equally distributed.

In [10]:
train_set = split_dataset(7,9,0.9,8000,train_fmnist)
test_set = split_dataset(7,9,0.5,1000,test_fmnist)
val_set = split_dataset(7,9,0.5,50,val_fmnist)

In [11]:
hyperparameters = {
    'lr' : 1e-3,
    'momentum' : 0.9,
    'batch_size' : 128,
    'epoch' : 5000,
}

In [12]:
train_loader = DataLoader(dataset=train_set, batch_size = hyperparameters['batch_size'], shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=hyperparameters['batch_size'], shuffle=False)
val_loader = DataLoader(dataset=val_set, batch_size=hyperparameters['batch_size'], shuffle=False)

In [16]:
model = CNN().to(device)
opt = optim.SGD(model.parameters(), lr=hyperparameters['lr'])
loss_fn = nn.BCEWithLogitsLoss().to(device)

## Apply reweighting method to CNN

In [17]:
for i in tqdm(range(1, hyperparameters['epoch']+1)):
  model.train()
  train_loss, train_acc = 0, 0
  images, labels = next(iter(train_loader))

  images_tr = images.to(device)
  labels_tr = labels.to(device)

  opt.zero_grad()

  with higher.innerloop_ctx(model, opt) as (meta_model, meta_opt):
    meta_train_outputs = meta_model(images_tr).squeeze()
    loss_fn.reduction = 'none'
    meta_train_loss = loss_fn(meta_train_outputs, labels.float())
    eps = torch.zeros(meta_train_loss.size(), requires_grad=True).to(device)
    # construct the computational graph
    meta_train_loss = torch.sum(eps * meta_train_loss)
    meta_opt.step(meta_train_loss)

    images_meta, labels_meta = next(iter(val_loader))
    y_g_hat = meta_model(images_meta).squeeze()

    loss_fn.reduction = 'mean'
    meta_val_loss = loss_fn(y_g_hat, labels_meta.float())
    # take the gradient wrt epsilon
    eps_grads = torch.autograd.grad(meta_val_loss, eps)[0].detach()
  # limit the weight >=0 and normalize it
  w_tilde = torch.clamp(-eps_grads, min=0)
  l1_norm = torch.sum(w_tilde)
  if l1_norm != 0:
      w = w_tilde / l1_norm
  else:
      w = w_tilde

  y_f_hat = model(images).squeeze()
  loss_f_hat = torch.sum(w * loss_fn(y_f_hat, labels.float()))
  loss_f_hat.backward()
  opt.step()

  train_loss += loss_f_hat.item()
  pred_labels = (F.sigmoid(y_f_hat) > 0.5).int()
  train_acc += torch.sum(torch.eq(pred_labels, labels)).item()

  if i % 10 == 0 and i != 0:
      model.eval()
      test_acc = []
      for i, (images_test, labels_test) in enumerate(test_loader):
          images_test = images_test.to(device)
          labels_test = labels_test.to(device)

          y_hat = model(images_test).to(device)
          prediction = (F.sigmoid(y_hat) > 0.5).int()
          test_acc.append((torch.flatten(prediction).int() == labels_test.int()).int())

      acc = torch.mean(torch.cat(test_acc,dim=0).float())
      print(acc)

  0%|          | 10/5000 [00:03<33:03,  2.52it/s]

tensor(0.5000)


  0%|          | 13/5000 [00:04<26:51,  3.09it/s]


KeyboardInterrupt: 