In [1]:
%cd ..


/pub/hofmann-scratch/glanzillo/ded


In [2]:

import importlib
import json
import math
import os
import socket
import sys
import time


internal_path = os.path.abspath(os.path.join('.'))
sys.path.append(internal_path)
sys.path.append(internal_path + '/datasets')
sys.path.append(internal_path + '/utils')

import datetime
import uuid
from argparse import ArgumentParser

import setproctitle
import torch
import numpy as np
import pandas as pd 
import json

from copy import deepcopy

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import AutoMinorLocator
from matplotlib.collections import LineCollection



import shutil
from utils.args import add_management_args, add_rehearsal_args
from utils.conf import set_random_seed, get_device, base_path
from utils.status import ProgressBar
from utils.stil_losses import *
from utils.nets import *
from torchvision.datasets import ImageNet, ImageFolder
from torch.utils.data import ConcatDataset
from torchvision.models import efficientnet_v2_s, resnet50, ResNet50_Weights, googlenet, efficientnet_b0, mobilenet_v3_large
from utils.eval import evaluate, validation_and_agreement, distance_models, validation_agreement_function_distance
from dataset_utils.data_utils import load_dataset, CIFAR100sparse2coarse
from torch.distributions import Categorical




## Teacher entropy calculation

maximum possible entropy = $$-log(1/C) = log(C)$$

In [3]:
def compute_entropy(data_loader, model, device, C, temperature=1):
    # running estimate of the outer products and mean
    entropy = 0; labels_entropy=0; total=0
    T = temperature
    progress_bar = ProgressBar(verbose=True)

    for i, data in enumerate(data_loader):
        #if i==10: break # for testing
        with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                probabilities = torch.nn.functional.softmax(outputs/T, dim=1)
                entropy += Categorical(probs = probabilities).entropy().sum()
                labels_entropy += Categorical(probs=F.one_hot(labels, num_classes=C)).entropy().sum()
                total += labels.shape[0]
                
        progress_bar.prog(i, len(data_loader), -1, 'Computing entropy', i/(len(data_loader)))  
        
    
    return (entropy/total).detach().cpu().numpy(), (labels_entropy/total).detach().cpu().numpy()

In [4]:
GPUID = 0
os.environ["CUDA_VISIBLE_DEVICES"]=str(GPUID)
device = get_device([GPUID])

### IMAGENET

maximum entropy = $6.9$

In [5]:
NUM_SAMPLES = 500000

In [6]:
imagenet_root = "/local/home/stuff/imagenet/"
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

inference_transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])

train_dataset = ImageFolder(imagenet_root+'train', inference_transform)
all_indices = set(range(len(train_dataset)))
random_indices = np.random.choice(list(all_indices), size=NUM_SAMPLES, replace=False)
data = Subset(train_dataset, random_indices)
loader = DataLoader(data, batch_size=32, shuffle=True, num_workers=4, pin_memory=False)

# val_dataset = ImageFolder(imagenet_root+'val', inference_transform)
# all_data = ConcatDataset([train_dataset, val_dataset])
# all_data_loader = DataLoader(all_data, batch_size=64, shuffle=True, num_workers=4, pin_memory=False)

In [7]:
def load_checkpoint(best=False, filename='checkpoint.pth.tar', student=False):
    path = base_path() + "chkpts" + "/" + "imagenet" + "/" + "resnet50/"
    if best: filepath = path + 'model_best.pth.tar'
    else: filepath = path + filename
    if os.path.exists(filepath):
          print(f"Loading existing checkpoint {filepath}")
          checkpoint = torch.load(filepath)  
          if not student: 
              new_state_dict = {k.replace('module.','',1):v for (k,v) in checkpoint['state_dict'].items()}
              checkpoint['state_dict'] = new_state_dict
          return checkpoint
    return None 

In [8]:

# initialising the model
teacher =  resnet50(weights=None)

CHKPT_NAME = 'rn50_2023-02-21_10-45-30_best.ckpt'
checkpoint = load_checkpoint(best=False, filename=CHKPT_NAME, student=False) 
if checkpoint: 
        teacher.load_state_dict(checkpoint['state_dict'])
        teacher.to(device)

Loading existing checkpoint ./logs/chkpts/imagenet/resnet50/rn50_2023-02-21_10-45-30_best.ckpt


In [9]:
teacher_entropy, label_entropy = compute_entropy(loader, teacher, C=1000, device=device)
teacher_entropy, label_entropy

[ 11-28 | 11:52 ] Task Computing entropy | epoch -1: |██████████████████████████████████████████████████| 5.6 ep/h | loss: 0.999936 ||

(array(0.88050705, dtype=float32), array(1.192093e-07, dtype=float32))

### CIFAR 5M

maximim entropy = $2.3$

In [10]:
NUM_SAMPLES = 500000

In [11]:
C5m_train, C5m_test = load_dataset('cifar5m', augment=False)

Loading CIFAR 5mil...
Loaded part 1/6


KeyboardInterrupt: 

In [None]:


print(f"Randomly drawing {NUM_SAMPLES} samples for the Cifar5M base")
all_indices = set(range(len(C5m_train)))
random_indices = np.random.choice(list(all_indices), size=NUM_SAMPLES, replace=False)
data = Subset(C5m_train, random_indices)
loader = DataLoader(data, batch_size=128, shuffle=True, num_workers=4, pin_memory=False)

Randomly drawing 500000 samples for the Cifar5M base


In [28]:
def load_checkpoint(best=False, filename='checkpoint.pth.tar', type='mnet'):
    """ Available network types: [mnet, convnet]"""
    path = base_path() + "chkpts" + "/" + "cifar5m" + "/" + f"{type}/"
    if best: filepath = path + 'model_best.ckpt'
    else: filepath = path + filename
    if os.path.exists(filepath):
          print(f"Loading existing checkpoint {filepath}")
          checkpoint = torch.load(filepath)
          return checkpoint
    return None 


In [32]:
# ENTROPY = 0.35
teacher = mobilenet_v3_large(num_classes=10) # adjusting for CIFAR 
CHKPT_NAME = f'mnet-teacher.ckpt' 

In [29]:
# ENTROPY = 2.26
teacher = make_cnn(c=20, num_classes=10, use_batch_norm=True)
CHKPT_NAME = f'convnet-teacher.ckpt' 

CNN made with 154250 parameters


In [33]:
checkpoint = load_checkpoint(best=False, filename=CHKPT_NAME, type='mnet') 
if checkpoint: 
        teacher.load_state_dict(checkpoint['state_dict'])
        teacher.to(device)

Loading existing checkpoint ./logs/chkpts/cifar5m/mnet/mnet-teacher.ckpt


In [37]:
teacher_entropy, label_entropy = compute_entropy(loader, teacher, C=10, device=device)
teacher_entropy, label_entropy

[ 11-28 | 11:28 ] Task Computing entropy | epoch -1: |██████████████████████████████████████████████████| 101.68 ep/h | loss: 0.99974405 |

(array(0.355543, dtype=float32), array(1.192093e-07, dtype=float32))

### CIFAR 100

maximum entropy = $4.6$

In [12]:

C100_train, C100_val = load_dataset('cifar100', augment=False)
all_data = ConcatDataset([C100_train, C100_val])
loader = DataLoader(all_data, batch_size=128, shuffle=True, num_workers=4, pin_memory=False)

In [13]:
def load_checkpoint(best=False, filename='checkpoint.pth.tar', distributed=False):
    path = base_path() + "chkpts" + "/" + "cifar100" + "/" + "resnet18/"
    if best: filepath = path + 'model_best.ckpt'
    else: filepath = path + filename
    if os.path.exists(filepath):
          print(f"Loading existing checkpoint {filepath}")
          checkpoint = torch.load(filepath)
          return checkpoint
    return None 

In [14]:
# initialising the model
teacher = resnet18(num_classes=100)
#teacher = make_cnn(c=150, num_classes=100, use_batch_norm=True)
CHKPT_NAME = 'resnet18-teacher.ckpt'
#CHKPT_NAME = 'convnet150-teacher.ckpt'
checkpoint = load_checkpoint(best=False, filename=CHKPT_NAME) 
if checkpoint: 
        teacher.load_state_dict(checkpoint['state_dict'])
        teacher.to(device)

Loading existing checkpoint ./logs/chkpts/cifar100/resnet18/resnet18-teacher.ckpt


In [8]:
teacher_entropy, label_entropy = compute_entropy(loader, teacher, C=100, device=device)
teacher_entropy, label_entropy

[ 11-28 | 11:30 ] Task Computing entropy | epoch -1: |██████████████████████████████████████████████████| 735.67 ep/h | loss: 0.9978678 ||

(array(0.44312614, dtype=float32), array(1.192093e-07, dtype=float32))

In [15]:
teacher_entropy, label_entropy = compute_entropy(loader, teacher, C=100, device=device, temperature=10)
teacher_entropy, label_entropy

[ 11-28 | 11:54 ] Task Computing entropy | epoch -1: |██████████████████████████████████████████████████| 751.55 ep/h | loss: 0.9978678 ||

(array(4.5765233, dtype=float32), array(1.192093e-07, dtype=float32))

In [16]:
teacher_entropy, label_entropy = compute_entropy(loader, teacher, C=100, device=device, temperature=0.1)
teacher_entropy, label_entropy

[ 11-28 | 11:54 ] Task Computing entropy | epoch -1: |██████████████████████████████████████████████████| 743.29 ep/h | loss: 0.9978678 ||

(array(0.02093775, dtype=float32), array(1.192093e-07, dtype=float32))