In [25]:
import os
import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from attributedataset.datasetutils import get_dataloader

In [33]:
yml = 'configs/pascal_config.yml'
with open(yml, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [34]:
train_dataloader, test_dataloader, num_classes = get_dataloader(config)

In [35]:
smpl = next(iter(train_dataloader))

In [36]:
smpl[0].shape

torch.Size([16, 2048, 14, 14])

In [37]:
smpl[0].dtype

torch.float32

In [29]:
class LargeLossMatters(nn.Module):
    def __init__(self,
                 num_classes,
                 backbone='resnet50',
                 freeze_backbone=False,
                 use_feature=False,
                 mod_schemes='LL-R',
                 delta_rel=0.1):
      super().__init__()

      self.num_classes = num_classes
      self.mod_schemes = mod_schemes
      self.delta_rel = delta_rel / 100
      self.use_feature = use_feature
      self.clean_rate = 1.0

      if use_feature:
        self.backbone = None
      elif backbone == 'resnet50':
        self.backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2]) # (N, 2048, 7, 7)
      else:
        raise NotImplementedError
      
      if not use_feature:
        if freeze_backbone:
          for param in self.backbone.parameters():
            param.requires_grad = False
        else:
          for param in self.backbone.parameters():
            param.requires_grad = True
      
      self.fc = nn.Linear(2048, num_classes)
    
    def forward(self, x):
      if self.backbone is not None:
        features = self.backbone(x) # (N, 2048, 7, 7)
      else:
        features = x

      print(features.shape, type(features), features.dtype)

      # adaptive global average pooling
      features = F.adaptive_avg_pool2d(features, (1)).squeeze(-1).squeeze(-1) # (N, 2048, 1, 1)
      
      logits = self.fc(features) # (N, num_classes)
      return logits
    
    def loss(self, x, labels):
      """
      Args:
        x: (N, num_classes)
        labels: (N, num_classes) 
      """
      preds = self.forward(x)

      batch_size = int(labels.shape[0])
      num_classes = int(labels.shape[1])

      loss_fn = nn.BCEWithLogitsLoss(reduction='none')
      loss_matrix = loss_fn(preds, labels) # (N, num_classes)
      corrected_loss_matrix = loss_fn(preds, torch.logical_not(labels).float()) # (N, num_classes)
      correction_idx = None

      if self.clean_rate == 1.0:
        return loss_matrix.mean(), correction_idx

      if self.mod_schemes == 'LL-R':
        k = math.ceil(batch_size * num_classes * (1-self.clean_rate))
      elif self.mod_schemes == 'LL-Ct':
        k = math.ceil(batch_size * num_classes * (1-self.clean_rate))
      elif self.mod_schemes == 'LL-Cp':
        k = math.ceil(batch_size * num_classes * self.delta_rel)
      else:
        raise NotImplementedError
      
      unobserved_loss = (labels == 0).bool() * loss_matrix # (N)
      try:
        topk = torch.topk(unobserved_loss.flatten(), k)
      except:
        print(batch_size)
        print(num_classes)
        print(k)
        raise NotImplementedError('topk error')
      topk_lossval = topk.values[-1]
      correction_idx = torch.where(unobserved_loss > topk_lossval)

      if self.mod_schemes == 'LL-R':
        corrected_loss_matrix = torch.zeros_like(loss_matrix)

      loss_matrix = torch.where(unobserved_loss < topk_lossval, loss_matrix, corrected_loss_matrix)

      return loss_matrix.mean(), correction_idx
    
    def get_cam(self, x):
      features = self.backbone(x) # (N, 2048, 7, 7)
      CAM = F.conv2d(features, self.fc.weight.unsqueeze(-1).unsqueeze(-1)) # (N, num_classes, 1, 1)
      return CAM
    
    def unfreeze_backbone(self):
      for param in self.backbone.parameters():
          param.requires_grad = True

    def forward_linear(self, x):
      x = self.fc(x)
      return x
    
    def decrease_clean_rate(self):
      self.clean_rate = self.clean_rate - self.delta_rel

In [30]:
batch = next(iter(train_dataloader))

In [31]:
model = LargeLossMatters(num_classes, backbone='resnet50', freeze_backbone=False, use_feature=False, mod_schemes='LL-R', delta_rel=0.1)

In [32]:
X, y = batch
logits = model(X)

torch.Size([16, 2048, 14, 14]) <class 'torch.Tensor'> torch.float32
