In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [22]:
from torchmetrics.functional import accuracy
from src.utils.common import Accumulator


def topIndices(vec, x, largest):
    """Rank the entries of ``vec`` and split the ranking after ``x`` positions.

    Args:
        vec: tensor of scores; it is ranked with ``torch.argsort``.
        x: how many leading positions of the ranking form the first group.
        largest: when true the ranking is descending (highest scores first),
            otherwise ascending.

    Returns:
        A pair ``(top_x_idxs, other_idxs)`` of index tensors: the first ``x``
        ranked indices and all remaining ones, in ranking order.
    """
    ranking = torch.argsort(vec, descending=largest)
    head = ranking[:x]
    tail = ranking[x:]
    return head, tail

def irreducibleLoss(data=None, target=None, global_index=None,
                    irreducible_loss_model=None, target_device=None):
    """Return the per-sample irreducible loss for one batch.

    Two kinds of ``irreducible_loss_model`` are supported:

    * a tensor (or any ``torch.Tensor`` subclass, e.g. ``nn.Parameter``) that
      acts as a lookup table of precomputed losses, indexed by
      ``global_index``;
    * anything else is treated as a callable model: the loss is the
      unreduced cross-entropy of ``irreducible_loss_model(data)`` against
      ``target``.

    Args:
        data: model inputs; only used in the callable-model case.
        target: class labels; only used in the callable-model case.
        global_index: index into the lookup table; only used in the tensor case.
        irreducible_loss_model: lookup tensor or callable model (see above).
        target_device: if given, the lookup tensor is moved to this device
            before indexing. Ignored in the callable-model case.

    Returns:
        Tensor of per-sample losses (no reduction applied).
    """
    # isinstance (not an exact type check) so Tensor subclasses are looked up
    # as tables rather than being called as if they were models.
    if isinstance(irreducible_loss_model, torch.Tensor):
        loss_table = irreducible_loss_model
        if target_device is not None and loss_table.device != target_device:
            loss_table = loss_table.to(device=target_device)
        # Keep a tensor index on the table's device; indexing across devices
        # is not supported in every direction.
        if isinstance(global_index, torch.Tensor) and global_index.device != loss_table.device:
            global_index = global_index.to(loss_table.device)
        return loss_table[global_index]

    return F.cross_entropy(irreducible_loss_model(data), target, reduction="none")

class ReducibleLossSelection:
    """Pick the samples of a batch with the largest reducible loss.

    The reducible loss of a sample is the large model's per-sample
    cross-entropy minus its irreducible loss (see ``irreducibleLoss``).
    """
    # Class-level flag; it is not read anywhere inside this class.
    bald = False

    def __call__(self, selected_batch_size, data=None, target=None, global_index=None, large_model=None,
        irreducible_loss_model=None):
        """Score a batch and return the indices of the samples to train on.

        Args:
            selected_batch_size: number of samples to keep.
            data: batch inputs fed to ``large_model``.
            target: batch labels.
            global_index: passed through to ``irreducibleLoss``.
            large_model: model whose per-sample loss is being ranked.
            irreducible_loss_model: passed through to ``irreducibleLoss``.

        Returns:
            ``(top_x_idxs, selected_irreducible_loss)``: positions within the
            batch of the kept samples, and their irreducible losses.
        """
        # Scoring only: nothing in here needs gradients.
        with torch.no_grad():
            per_sample_loss = F.cross_entropy(large_model(data), target, reduction="none")
            irreducible_loss = irreducibleLoss(
                data,
                target,
                global_index,
                irreducible_loss_model,
                per_sample_loss.device,
            )
            ranking_scores = per_sample_loss - irreducible_loss
            top_x_idxs = topIndices(ranking_scores, selected_batch_size, largest=True)[0]
            return top_x_idxs, irreducible_loss[top_x_idxs]
    
# Module-level selection policy; train_batch_rho_loss reads this global to
# choose which samples of each batch are trained on.
selection_method = ReducibleLossSelection()
# NOTE(review): this flag is not read by any code visible in this notebook —
# confirm whether it is still needed.
update_irreducible = True

def train_irr_loss(
        irreducible_loss_clf,
        data_loader,
        loss_fn,
        opt,
        accumulator : "Accumulator" = None ,
        selection__mode = True, presample = 3.0 ):
    """Train the irreducible-loss classifier for one epoch.

    Args:
        irreducible_loss_clf: model to train; it is left in training mode.
        data_loader: iterable of ``(x, labels)`` batches that also supports
            ``len()`` and exposes a ``dataset`` attribute supporting ``len()``.
        loss_fn: returns per-sample losses; their mean is back-propagated.
        opt: optimizer over the classifier's parameters.
        accumulator: unused; accepted for signature compatibility.
        selection__mode: unused; accepted for signature compatibility.
        presample: unused; accepted for signature compatibility.

    Returns:
        ``{"ret_loss": mean batch loss, "accuracy": correct / len(dataset)}``.
        Both values are ``0.0`` when there is nothing to average over.
    """
    # BEGIN Solution (do not delete this comment!)
    irreducible_loss_clf.train(True)  # Set model to training mode

    # Follow the model's own device instead of hard-coding CUDA, so the same
    # code runs on CPU-only machines. A parameter-free model leaves the batch
    # where the loader put it.
    first_param = next(iter(irreducible_loss_clf.parameters()), None)
    device = first_param.device if first_param is not None else None

    total_loss, correct = 0.0, 0
    for x, labels in data_loader:
        if device is not None:
            x, labels = x.to(device), labels.to(device)

        opt.zero_grad()
        all_pred = irreducible_loss_clf(x)
        loss = loss_fn(all_pred, labels).mean()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        with torch.no_grad():
            correct += (all_pred.argmax(dim=1) == labels).sum().item()

    # Guard the averages so an empty loader/dataset does not divide by zero.
    n_batches = len(data_loader)
    n_samples = len(data_loader.dataset)
    ret_loss = total_loss / n_batches if n_batches else 0.0
    epoch_accuracy = correct / n_samples if n_samples else 0.0

    return {"ret_loss": ret_loss, "accuracy": epoch_accuracy}


def train_batch_rho_loss(
        large_model,
        irreducible_loss_model,
        batch,
        loss_fn,
        optimizer,
        accumulator : Accumulator,
        selection__mode = True, presample = 3.0 ):
    """Run one selection-then-train step on a single (large) batch.

    The module-level ``selection_method`` scores every sample of the batch,
    the best ``len(data) / presample`` samples (at least one) are kept, and a
    single optimizer step is taken on those samples only.

    Args:
        large_model: model being trained.
        irreducible_loss_model: lookup tensor or model handed to
            ``selection_method`` to obtain irreducible losses.
        batch: ``(global_index, data, target)``.
        loss_fn: returns per-sample losses; their mean is back-propagated.
        optimizer: optimizer over ``large_model``'s parameters.
        accumulator: receives ``train_loss`` and ``train_acc`` through
            ``accumulator.average`` as ``(value, n_selected)`` pairs
            (presumably value and weight — confirm against ``Accumulator``).
        selection__mode: unused; accepted for signature compatibility.
        presample: ratio between the incoming batch size and the number of
            samples actually trained on.

    Returns:
        None. Metrics are reported through ``accumulator``.
    """
    global_index, data, target = batch
    batch_size = len(data)
    selected_batch_size = max(1, int(batch_size / presample))

    # Selection is a pure scoring pass, so both networks run in eval mode:
    # otherwise BatchNorm/Dropout layers would behave as in training while
    # the batch is only being ranked.
    large_model.eval()
    if isinstance(irreducible_loss_model, nn.Module):
        irreducible_loss_model.eval()

    # The second return value (irreducible losses of the kept samples) is not
    # needed for the training step.
    selected_indices, _ = selection_method(
        selected_batch_size=selected_batch_size,
        data=data,
        target=target,
        global_index=global_index,
        large_model=large_model,
        irreducible_loss_model=irreducible_loss_model,
    )

    large_model.train()  # back to training mode for the gradient step

    data, target = data[selected_indices], target[selected_indices]

    optimizer.zero_grad()
    logits = large_model(data)
    mean_loss = loss_fn(logits, target).mean()
    mean_loss.backward()
    optimizer.step()

    # training metrics on the selected samples only
    n = len(logits)
    acc = (logits.argmax(dim=1) == target).sum().item() / n

    accumulator.average(
        train_loss=(mean_loss.item(), n),
        train_acc=(acc, n))


In [9]:
import copy

def train_full( model, model_irr, train_dataloader, train_irr_loader, loss_fn, optimizer,optimizer_irr,
                n_epochs, eval = None, callback=None, presample=3, tau_th = None):
    """Obtain the irreducible-loss model, then train ``model`` with batch selection.

    Phase 1 restores ``model_irr`` from the ``"irr_model"`` checkpoint if it
    can be loaded; otherwise ``model_irr`` is trained for ``n_epochs`` epochs
    on ``train_irr_loader`` and the checkpoint is written.
    Phase 2 trains ``model`` for ``n_epochs`` epochs, calling
    ``train_batch_rho_loss`` on every batch of ``train_dataloader``.

    Args:
        model: model to train with selection.
        model_irr: irreducible-loss model.
        train_dataloader: yields ``(idxs, X_batch, y_batch)``.
        train_irr_loader: loader passed to ``train_irr_loss``.
        loss_fn: per-sample loss function shared by both phases.
        optimizer: optimizer for ``model``.
        optimizer_irr: optimizer for ``model_irr``.
        n_epochs: number of epochs for each phase.
        eval: optional callable ``eval(model) -> dict`` whose result is merged
            into the callback arguments. (The name shadows the builtin; it is
            kept so existing keyword callers continue to work.)
        callback: optional object providing ``setMeta(...)`` and being callable
            with the epoch metrics; its return value is shown on the progress bar.
        presample: ratio of loaded batch size to trained batch size, forwarded
            to ``train_batch_rho_loss``.
        tau_th: only forwarded to ``callback.setMeta``.
    """
    checkpoint_path = "irr_model"

    if callback :
        callback.setMeta(
            large_batch = int( train_dataloader.batch_size),
            n_epochs = n_epochs,
            presample = presample,
            tau_th = tau_th)

    # Prefer the ad-hoc ``device`` attribute when the caller set one, but do
    # not require it: a plain nn.Module has no such attribute.
    d = getattr(model, "device", None)
    if d is None:
        d = next(model.parameters()).device

    try:
        # The checkpoint holds the irreducible-loss weights, so it must be
        # restored into model_irr and never into the model being trained.
        model_irr.load_state_dict(torch.load(checkpoint_path, map_location=d))
        print("irr model loades")
    except Exception as load_error:
        # Best effort: any failure to restore falls back to training, but the
        # reason is reported and KeyboardInterrupt/SystemExit still propagate.
        print(f"irr model not loaded ({load_error!r}); training it instead")
        epochs = tqdm(range(n_epochs), desc='Irr epochs', leave=True)
        for i_epoch in epochs:
            dict_ = train_irr_loss(model_irr, train_irr_loader, loss_fn, optimizer_irr)
            epochs.set_postfix(dict_)
        # Save only freshly trained weights; saving after a successful load
        # would merely rewrite the same file.
        torch.save(model_irr.state_dict(), checkpoint_path)

    # From here on model_irr is only used to score samples.
    model_irr.eval()

    epochs = tqdm(range(n_epochs), desc='Epochs', leave=True)
    for i_epoch in epochs:
        accum = Accumulator()

        for idxs, X_batch, y_batch in train_dataloader:
            # presample goes by keyword: positionally it would land in the
            # unrelated ``selection__mode`` parameter.
            train_batch_rho_loss(model,
                                 model_irr,
                                 (idxs, X_batch.to(d), y_batch.to(d)),
                                 loss_fn,
                                 optimizer,
                                 accum,
                                 presample=presample)

        if callback :
            val_scores = eval(model) if eval else {}
            cb_dict = callback( **accum.getAll(), **val_scores)
            epochs.set_postfix(cb_dict)

In [29]:
def evaluate(model, dataloader, loss_fn):
    """Compute mean loss and accuracy of ``model`` over ``dataloader``.

    Args:
        model: model to evaluate; it is switched to eval mode and left there.
        dataloader: iterable of batches whose last two items are the inputs
            and the labels, so both ``(X, y)`` and ``(idx, X, y)`` batches
            are accepted.
        loss_fn: returns per-sample losses; their mean is reported.

    Returns:
        ``(loss, acc)`` as Python floats, computed over all samples at once.
    """
    model.eval()

    # Use the ad-hoc ``device`` attribute when present, otherwise fall back
    # to the device of the model's parameters (CPU for parameter-free models).
    device = getattr(model, "device", None)
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    logits = []
    targets = []
    with torch.no_grad():
        for batch in dataloader:
            # Any leading items (e.g. sample indices) are ignored.
            *_, X_batch, y_batch = batch
            logits.append(model(X_batch.to(device)).cpu())
            targets.append(y_batch.cpu())

    logits = torch.cat(logits)
    targets = torch.cat(targets)
    loss = loss_fn(logits, targets).mean().item()
    acc = (logits.argmax(dim=1) == targets).sum().item() / len(targets)
    return loss, acc


In [1]:
from src.utils.data_loaders import train_dataloader, test_dataloader, train_val_dataloader

# Loader used to train the irreducible-loss model.
# NOTE(review): subset=0.25 presumably restricts it to a quarter of the
# training data — confirm in src.utils.data_loaders.
train_irr_loader  = train_dataloader(batch_size=120, subset=0.25)#128*3
# NOTE(review): index=True presumably makes batches carry sample indices.
# train_full unpacks (idxs, X, y) from train_loader, while evaluate unpacks
# (X, y) from test_loader — confirm what each returned loader actually yields.
train_loader, test_loader = train_val_dataloader(batch_size=120, index=True)
#test_loader = test_dataloader(batch_size=120)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Enable IPython's autoreload extension for iterating on helper modules.
%load_ext autoreload

# NOTE(review): with "%autoreload 1" only modules registered through %aimport
# are reloaded. The visible cells import from src.utils.common, src.models and
# src.utils.data_loaders, not from common_utils — confirm this is the intended
# module.
%aimport common_utils
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
%autoreload 1
from src.models import ResNet50
from src.utils.common import UnCallBack

# Experiment driver: build both networks, their optimizers and the metric
# callback, then run train_full and save the recorded metrics.
# NOTE(review): no random seed is set in the visible cells, so runs are not
# reproducible as written.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


# Model trained with batch selection. The extra ``device`` attribute is set by
# hand because train_full and evaluate read ``model.device``.
model = ResNet50()
model.to(device)
model.device = device

# Irreducible-loss model (same architecture, separately initialised).
model_irr = ResNet50()
model_irr.to(device)
model_irr.device = device



# reduction='none' keeps per-sample losses; the training and evaluation
# functions in this notebook take ``.mean()`` themselves.
loss_fn = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3 )
optimizer_irr = torch.optim.Adam(model_irr.parameters(), lr = 1e-3 )

# NOTE(review): the visible training code only reports train_loss/train_acc
# (via the accumulator) and val_loss/val_acc (via eval_callback);
# 'train_w_loss' and 'train_uniform_cnt' are not produced here — confirm how
# UnCallBack treats missing keys.
callback = UnCallBack( info_list = ['train_loss', 'train_acc', 'train_w_loss', 'val_loss', 'val_acc', 'train_uniform_cnt'])

def eval_callback(model):
    """Evaluate ``model`` on the global ``test_loader`` and return ``val_loss``/``val_acc``."""
    loss, acc =evaluate(model, test_loader, loss_fn)
    return {"val_loss": loss, "val_acc": acc}

#print(len(train_loader))
#print(len(test_loader))
#print(len(train_irr_loader))

# tau_th is only forwarded to callback.setMeta inside train_full.
train_full(model, 
           model_irr,
           train_loader, 
           train_irr_loader,
           loss_fn, 
           optimizer, 
           optimizer_irr,
           n_epochs=50, 
           eval=eval_callback, 
           callback=callback, 
           presample=3, 
           tau_th = None)

# NOTE(review): presumably persists the recorded metrics under the name
# "rho_loss" — confirm against UnCallBack.save.
callback.save("rho_loss")

cuda
irr model loades


Epochs:   0%|          | 0/50 [02:20<?, ?it/s]


ValueError: not enough values to unpack (expected 3, got 2)