In [1]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler

import torchvision
import torchvision.transforms as T

# PAWS TRAINING

### STL10 Dataset

In [17]:
num_classes = 10
unlabeled_batch_size = 128
support_samples_per_class = 5
num_epochs = 50

In [18]:
unlabeled_data = torchvision.datasets.STL10('./content',split='unlabeled',download=True,transform=T.ToTensor())
labeled_data = torchvision.datasets.STL10('./content',split='train',download=True,transform=T.ToTensor())
test_data = torchvision.datasets.STL10('./content',split='test',download=True,transform=T.ToTensor())

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Class balanced sampling of support samples

In [19]:
# Reference - https://discuss.pytorch.org/t/load-the-same-number-of-data-per-class/65198
class BalancedBatchSampler(BatchSampler):
    def __init__(self, dataset, n_classes, n_samples):
        loader = DataLoader(dataset) # BS = 1 by default
        self.labels_list = []
        for _, label in loader:
            self.labels_list.append(label) # extract all labels in the dataset
        self.labels = torch.LongTensor(self.labels_list) # dtype of LongTensor = int64
        self.labels_set = list(set(self.labels.numpy())) # converts tensor(self.label) to numpy.... set() extracts all unique elements in a np array and then it is converted to a list.
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set} # a dictionary of indices (images) for each class label (0,1,2..)
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l]) # shuffle all indices
        self.used_label_indices_count = {label: 0 for label in self.labels_set} # makes sure they donot overlap upon iter.
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False) # choose unique labels randomly
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][ # extend is similar to append operation. for every class_label indices of length n_samples is extracted.
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): # reset upon depletion
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices # simlar to return but it is of the type object:generator
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size

In [20]:
balanced_batch_sampler = BalancedBatchSampler(labeled_data, num_classes, support_samples_per_class)
support_dataloader = torch.utils.data.DataLoader(labeled_data, batch_sampler=balanced_batch_sampler)
unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_data, batch_size=unlabeled_batch_size, shuffle=True)

# use the following format for loading support samples alongside unlabeled samples
# support_loader = iter(support_dataloader)
# x,y = next(support_loader)

DATA AUGMENTATION



In [21]:
class DataAugmentation:
  def __init__(self, img_size=96, crop_scale=(0.3,1.0), s=1): # s - jitter strenth for unsupervised

    jitter1 = T.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)

    self.transforms1 = torch.nn.Sequential(
        T.RandomResizedCrop(size=img_size, scale=crop_scale), 
        T.RandomHorizontalFlip(),
        T.RandomApply([jitter1], p=0.8),
        T.Normalize(mean=[0.49139968, 0.48215827 ,0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
    )
    self.transforms2 = torch.nn.Sequential( 
        T.RandomHorizontalFlip(),
        T.Normalize(mean=[0.49139968, 0.48215827 ,0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
    )  
  def __call__(self, x, sx):
    return (self.transforms1(x),self.transforms1(x), self.transforms2(sx))

### MODEL and PAWS loss function

In [22]:
class PAWS_ResNet18(nn.Module):
    def __init__(self,):
        super(PAWS_ResNet18, self).__init__()

        resnet18 = torchvision.models.resnet18(pretrained=False)
        resnet18.fc = nn.Linear(resnet18.fc.in_features,128,bias=True)
        self.base_model = resnet18

        self.projection = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
        )

        self.augmentation = DataAugmentation()

    def forward(self,x, support_x):
      x1,x2, sx = self.augmentation(x, support_x)
      feature1 = self.projection(self.base_model(x1))
      feature2 = self.projection(self.base_model(x2))
      support = self.projection(self.base_model(sx))

      return (feature1,feature2,support)

model = PAWS_ResNet18()

In [23]:
class PAWS_loss(nn.Module):
    def __init__(self, tau=0.1, T=0.25):
        super(PAWS_loss, self).__init__()
        self.tau = tau
        self.T = T
        self.softmax = torch.nn.Softmax(dim=1)

    def sharpen(self, p):
        sharp_p = p**(1./self.T)
        sharp_p /= torch.sum(sharp_p, dim=1, keepdim=True)
        return sharp_p
    
    def snn(self, query, supports, labels):
        # Step 1: normalize embeddings
        query = torch.nn.functional.normalize(query)
        supports = torch.nn.functional.normalize(supports)
        # Step 2: compute similarlity between local embeddings
        return self.softmax(query @ supports.T / self.tau) @ labels
    
    def forward(self, anchor_views,
        anchor_supports,
        anchor_support_labels,
        target_views,
        target_supports,
        target_support_labels):
      
        # Step 1: compute anchor predictions
        probs = self.snn(anchor_views, anchor_supports, anchor_support_labels)

        # Step 2: compute targets for anchor predictions
        with torch.no_grad():
            targets = self.snn(target_views, target_supports, target_support_labels) # Note: target views and target supports must be detached.
            targets = self.sharpen(targets)

        # Step 3: compute cross-entropy loss H(targets, queries)
        loss = torch.mean(torch.sum(torch.log(probs**(-targets)), dim=1))

        # Step 4: compute me-max regularizer
        rloss = 0.
        avg_probs = torch.mean(self.sharpen(probs), dim=0)
        rloss -= torch.sum(torch.log(avg_probs**(-avg_probs)))

        return loss, rloss

### Train Loop

In [24]:
loss_fn = PAWS_loss()
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [25]:
model.to(device)

for epoch in range(num_epochs):

  # Reset Loss Values 
  cross_ent_loss = 0.0
  memax_loss = 0.0

  model.train()
  
  support_loader = iter(support_dataloader) # for every epoch, prepare the support set

  # Load unlabeled data
  for X,_ in unlabeled_dataloader:

    # load support samples
    try:
      sx,sy = next(support_loader) # load support samples
    except StopIteration:
      support_loader = iter(support_dataloader) # reset when support samples are deleted
      sx,sy = next(support_loader)

    # load tensors to device
    X = X.to(device)
    sy = nn.functional.one_hot(sy,num_classes=num_classes).float() # convert longint to float
    sx = sx.to(device)
    sy = sy.to(device)

    #Reseting Gradients
    optimizer.zero_grad()

    # Predictions
    anchor, target, supports = model(X,sx)

    # Calculate Loss
    ent, memax = loss_fn(anchor, supports, sy, target.detach(), supports.detach(), sy)

    cross_ent_loss += ent.item()
    memax_loss += memax.item()
    _loss = ent + memax

    # Update Parameters
    _loss.backward()
    optimizer.step()

  print("\nEpoch..........................................", epoch + 1)
  print("\tCross Entropy Loss = ", round(cross_ent_loss/float(unlabeled_batch_size),2))
  print("\tMean Entropy = ", round(memax_loss/-float(unlabeled_batch_size),2))

torch.save(model.base_model.state_dict(),'paws_weights.pth')


Epoch.......................................... 1
	Cross Entropy Loss =  10.1
	Mean Entropy =  13.81

Epoch.......................................... 2
	Cross Entropy Loss =  9.34
	Mean Entropy =  13.86

Epoch.......................................... 3
	Cross Entropy Loss =  9.09
	Mean Entropy =  13.86

Epoch.......................................... 4
	Cross Entropy Loss =  8.89
	Mean Entropy =  13.87

Epoch.......................................... 5
	Cross Entropy Loss =  8.74
	Mean Entropy =  13.87

Epoch.......................................... 6
	Cross Entropy Loss =  8.44
	Mean Entropy =  13.88

Epoch.......................................... 7
	Cross Entropy Loss =  8.56
	Mean Entropy =  13.88

Epoch.......................................... 8
	Cross Entropy Loss =  8.44
	Mean Entropy =  13.88

Epoch.......................................... 9
	Cross Entropy Loss =  8.28
	Mean Entropy =  13.89

Epoch.......................................... 10
	Cross Entropy Loss =  8.08
	M