In [1]:
import json
import os
import argparse

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from pathlib import Path
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append('../../lib/exlib/src')
from exlib.modules.sop import SOPConfig, get_chained_attr

In [2]:
from exlib.modules.sop import SOP
import copy

class SOPText(SOP):
    def __init__(self, 
                 config,
                 blackbox_model,
                 class_weights=None,
                 projection_layer=None,
                 ):
        super().__init__(config,
                            blackbox_model,
                            class_weights
                            )

        if projection_layer is not None:
            self.projection = copy.deepcopy(projection_layer)
        else:
            self.init_projection()

        # Initialize the weights of the model
        self.init_grads()

    def init_projection(self):
        self.projection = nn.Linear(1, self.hidden_size)

    def forward(self, 
                inputs, 
                segs=None, 
                input_mask_weights=None,
                epoch=-1, 
                mask_batch_size=16,
                label=None,
                return_tuple=False,
                kwargs={}):
        if epoch == -1:
            epoch = self.num_heads
        bsz, seq_len = inputs.shape
        
        # Mask (Group) generation
        if input_mask_weights is None:
            grouped_inputs_embeds, input_mask_weights, grouped_kwargs = self.group_generate(inputs, epoch, mask_batch_size, 
                                                                                            segs, kwargs)
            grouped_inputs = None
        else:
            grouped_inputs = inputs.unsqueeze(1) * input_mask_weights.unsqueeze(2) # directly apply mask
        
        # Backbone model
        logits, pooler_outputs = self.run_backbone(grouped_inputs, mask_batch_size, kwargs=grouped_kwargs)

        # Mask (Group) selection & aggregation
        weighted_logits, output_mask_weights, logits, pooler_outputs = self.group_select(logits, pooler_outputs, seq_len)

        if return_tuple:
            return self.get_results_tuple(weighted_logits, logits, pooler_outputs, input_mask_weights, output_mask_weights, bsz, label)
        else:
            return weighted_logits

    def get_results_tuple(self, weighted_logits, logits, pooler_outputs, input_mask_weights, output_mask_weights, bsz, label):
        raise NotImplementedError

    def run_backbone(self, masked_inputs=None, mask_batch_size=16, kwargs={}):  # TODO: Fix so that we don't need to know the input
        if masked_inputs is not None:
            bsz, num_masks, seq_len = masked_inputs.shape
            masked_inputs = masked_inputs.reshape(-1, seq_len)
            kwargs_flat = {k: v.reshape(-1, seq_len) for k, v in kwargs.items()}
        else:
            bsz, num_masks, seq_len, hidden_size = kwargs['inputs_embeds'].shape
            
            kwargs_flat = {k: v.reshape(-1, seq_len, hidden_size) if k == 'inputs_embeds' else v.reshape(-1, seq_len)
                           for k, v in kwargs.items()}
        logits = []
        pooler_outputs = []
        for i in range(0, bsz * num_masks, mask_batch_size):
            kwargs_i = {k: v[i:i+mask_batch_size] for k, v in kwargs_flat.items()}
            output_i = self.blackbox_model(
                masked_inputs[i:i+mask_batch_size] if masked_inputs is not None else None,
                **kwargs_i
            )
            pooler_i = output_i.pooler_output
            logits_i = output_i.logits
            logits.append(logits_i)
            pooler_outputs.append(pooler_i)

        logits = torch.cat(logits).view(bsz, num_masks, self.num_classes, -1)
        pooler_outputs = torch.cat(pooler_outputs).view(bsz, num_masks, self.hidden_size, -1)
        return logits, pooler_outputs
    
    def group_generate(self, inputs, epoch, mask_batch_size=16, segs=None, kwargs={}):
        bsz, seq_len = inputs.shape
        mask_embed = self.projection(torch.tensor([0]).int().to(inputs.device))
        projected_inputs = self.projection(inputs)
        
        if segs is None:   # word level
            projected_inputs = projected_inputs * self.projected_input_scale

            if self.num_masks_max != -1:
                input_dropout_idxs = torch.randperm(projected_inputs.shape[1])
                if 'attention_mask' in kwargs:
                    attention_mask_mult = kwargs['attention_mask'] * input_dropout_idxs
                else:
                    attention_mask_mult = input_dropout_idxs
                input_dropout_idxs = torch.argsort(attention_mask_mult, dim=-1).flip(-1)[:, :self.num_masks_max]
                batch_indices = torch.arange(bsz).unsqueeze(1).repeat(1, input_dropout_idxs.shape[-1])
                selected_projected_inputs = projected_inputs[batch_indices, input_dropout_idxs]
                projected_query = selected_projected_inputs
            else:
                projected_query = projected_inputs
            input_mask_weights_cand = self.input_attn(projected_query, projected_inputs, epoch=epoch)
            input_mask_weights_cand = input_mask_weights_cand.squeeze(1)

            input_mask_weights_cand = torch.clip(input_mask_weights_cand, max=1.0)
        else: # sentence level
            # With/without masks are a bit different. Should we make them the same? Need to experiment.
            bsz, num_segs, seq_len = segs.shape

            seged_inputs_embeds = projected_inputs.unsqueeze(1) * segs.unsqueeze(-1) + \
                               mask_embed.view(1,1,1,-1) * (1 - segs.unsqueeze(-1))
            
            seged_kwargs = {}
            for k, v in kwargs.items():
                seged_kwargs[k] = v.unsqueeze(1).expand(segs.shape).reshape(-1, seq_len)
            seged_kwargs['inputs_embeds'] = seged_inputs_embeds

            # TODO: always have seg for the part after sep token
            _, interm_outputs = self.run_backbone(None, mask_batch_size, kwargs=seged_kwargs)
            
            interm_outputs = interm_outputs.view(bsz, -1, self.hidden_size)
            interm_outputs = interm_outputs * self.projected_input_scale
            segment_mask_weights = self.input_attn(interm_outputs, interm_outputs, epoch=epoch)
            segment_mask_weights = segment_mask_weights.reshape(bsz, -1, num_segs)
            
            new_masks =  segs.unsqueeze(1) * segment_mask_weights.unsqueeze(-1)
            # (bsz, num_new_masks, num_masks, seq_len)
            input_mask_weights_cand = new_masks.sum(2)  # if one mask has it, then have it
            # todo: Can we simplify the above to be dot product?
            
        scale_factor = 1.0 / input_mask_weights_cand.max(dim=-1).values
        input_mask_weights_cand = input_mask_weights_cand * scale_factor.view(bsz, -1,1)

        # we are using iterative training
        # we will train some masks every epoch
        # the masks to train are selected by mod of epoch number
        # Dropout for training
        if self.training:
            dropout_idxs = torch.randperm(input_mask_weights_cand.shape[1])[:self.num_masks_sample]
            dropout_mask = torch.zeros(bsz, input_mask_weights_cand.shape[1]).to(inputs.device)
            dropout_mask[:,dropout_idxs] = 1
        else:
            dropout_mask = torch.ones(bsz, input_mask_weights_cand.shape[1]).to(inputs.device)
        
        input_mask_weights = input_mask_weights_cand[dropout_mask.bool()].clone()
        input_mask_weights = input_mask_weights.reshape(bsz, -1, seq_len)
        
        # Always add the second part of the sequence (in question answering, it would be the qa pair)
        input_mask_weights = input_mask_weights  + kwargs['token_type_ids'].unsqueeze(1)
        
        masked_inputs_embeds = projected_inputs.unsqueeze(1) * input_mask_weights.unsqueeze(-1) + \
                               mask_embed.view(1,1,1,-1) * (1 - input_mask_weights.unsqueeze(-1))
        
        masked_kwargs = {}
        for k, v in kwargs.items():
            masked_kwargs[k] = v.unsqueeze(1).expand(input_mask_weights.shape).reshape(-1, seq_len)
        masked_kwargs['inputs_embeds'] = masked_inputs_embeds
        
        return masked_inputs_embeds, input_mask_weights, masked_kwargs
    
    def group_select(self, logits, pooler_outputs, seq_len):
        raise NotImplementedError


class SOPTextCls(SOPText):
    def group_select(self, logits, pooler_outputs, seq_len):
        bsz, num_masks = logits.shape[:2]

        logits = logits.view(bsz, num_masks, self.num_classes)
        pooler_outputs = pooler_outputs.view(bsz, num_masks, self.hidden_size)

        query = self.class_weights.unsqueeze(0).expand(bsz, 
                                                    self.num_classes, 
                                                    self.hidden_size) #.to(logits.device)
        
        key = pooler_outputs
        weighted_logits, output_mask_weights = self.output_attn(query, key, logits)

        return weighted_logits, output_mask_weights, logits, pooler_outputs
    
    def get_results_tuple(self, weighted_logits, logits, pooler_outputs, input_mask_weights, output_mask_weights, bsz, label):
        # todo: debug for segmentation
        masks_aggr = None
        masks_aggr_pred_cls = None
        masks_max_pred_cls = None
        flat_masks = None

        if label is not None:
            predicted = label  # allow labels to be different
        else:
            _, predicted = torch.max(weighted_logits.data, -1)
        
        masks_mult = input_mask_weights.unsqueeze(2) * \
        output_mask_weights.unsqueeze(-1).unsqueeze(-1) # bsz, n_masks, n_cls
        
        masks_aggr = masks_mult.sum(1) # bsz, n_cls
        masks_aggr_pred_cls = masks_aggr[range(bsz), predicted].unsqueeze(1)
        max_mask_indices = output_mask_weights.max(2).values.max(1).indices
        masks_max_pred_cls = masks_mult[range(bsz),max_mask_indices,predicted].unsqueeze(1)
        flat_masks = compress_masks_image(input_mask_weights, output_mask_weights)
        return AttributionOutputSOP(weighted_logits,
                                    logits,
                                    pooler_outputs,
                                    input_mask_weights,
                                    output_mask_weights,
                                    masks_aggr_pred_cls,
                                    masks_max_pred_cls,
                                    masks_aggr,
                                    flat_masks)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 42
if SEED != -1:
    # Torch RNG
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Python RNG
    np.random.seed(SEED)
    random.seed(SEED)

In [4]:
# model paths
backbone_model_name = '../../pt_models/multirc_vanilla/best'
backbone_processor_name = 'bert-base-uncased'
# sop_config_path = 'configs/imagenet_m.json'

# data paths
# TRAIN_DATA_DIR = '../data/imagenet_m/train'
# VAL_DATA_DIR = '../data/imagenet_m/val'

# training args
batch_size = 2
lr = 0.0000005
num_epochs = 20
warmup_steps = 50
mask_batch_size = 4

# experiment args
exp_dir = '../../exps/multirc'
os.makedirs(exp_dir, exist_ok=True)

In [5]:
backbone_model = AutoModelForSequenceClassification.from_pretrained(backbone_model_name)
processor = AutoTokenizer.from_pretrained(backbone_processor_name)
backbone_config = AutoConfig.from_pretrained(backbone_model_name)

config = SOPConfig(
    # attn_patch_size=16,
    num_heads=1,
    num_masks_sample=8,
    num_masks_max=16,
    finetune_layers=['model.classifier']
)
config.__dict__.update(backbone_config.__dict__)
config.num_labels = len(backbone_config.label2id)
# config.save_pretrained(exp_dir)

In [6]:
from torch.utils.data import DataLoader
from datasets import load_dataset

SENT_SEPS = [processor.convert_tokens_to_ids(processor.tokenize(token)[0]) for token in [';',',','.','?','!',';']]
SEP = processor.convert_tokens_to_ids(processor.tokenize('[SEP]')[0])
print('SEP', SEP, 'SENT_SEPS', SENT_SEPS)

def sent_seg(input_ids):
    segs = []
    count = 1
    for i, input_id in enumerate(input_ids):
        if count in [0, -1]:
            if input_id == SEP:
                count = -1
            segs.append(count)
            continue
        else:
            if input_id in SENT_SEPS:
                segs.append(count)
                count += 1
            elif input_id == SEP:
                if count > 0:
                    count = 0
                    segs.append(count)
                else:
                    segs.append(count)
                    count = -1
            else: # normal character
                segs.append(count)
    return segs

def convert_idx_masks_to_bool_text(masks):
    """
    input: masks (1, seq_len)
    output: masks_bool (num_masks, seq_len)
    """
    unique_idxs = torch.sort(torch.unique(masks)).values
    unique_idxs = unique_idxs[unique_idxs != -1]
    unique_idxs = unique_idxs[unique_idxs != 0]
    idxs = unique_idxs.view(-1, 1)
    broadcasted_masks = masks.expand(unique_idxs.shape[0], 
                                     masks.shape[1])
    masks_bool = (broadcasted_masks == idxs)
    return masks_bool


def get_mask_transform_text(num_masks_max=200, processor=None):
    def mask_transform(mask):
        seg_mask_cut_off = num_masks_max
        # print('mask 1', mask)
        # if mask.max(dim=-1) > seg_mask_cut_off:
        # import pdb; pdb.set_trace()
        if mask.max(dim=-1).values.item() > seg_mask_cut_off:
            mask_new = (mask / (mask.max(dim=-1).values / seg_mask_cut_off)).int().float() + 1
            # bsz, seq_len = mask_new.shape
            # print('mask 2', mask_new)
            # import pdb; pdb.set_trace()
            mask_new[mask == 0] = 0
            mask_new[mask == -1] = -1
            mask = mask_new
        
        if mask.dtype != torch.bool:
            if len(mask.shape) == 1:
                mask = mask.unsqueeze(0)
            # print('mask', mask.shape)
            mask_bool = convert_idx_masks_to_bool_text(mask)
        # print(mask.shape)
        bsz, seq_len = mask.shape
        mask_bool = mask_bool.float()
        
        

        if bsz < seg_mask_cut_off:
            repeat_count = seg_mask_cut_off // bsz + 1
            mask_bool = torch.cat([mask_bool] * repeat_count, dim=0)

        # add additional mask afterwards
        mask_bool_sum = torch.sum(mask_bool[:seg_mask_cut_off - 1], dim=0, keepdim=True).bool()
        if False in mask_bool_sum:
            mask_bool = mask_bool[:seg_mask_cut_off - 1]
            # import pdb; pdb.set_trace()
            compensation_mask = (1 - mask_bool_sum.int()).bool()
            compensation_mask[mask == 0] = False
            compensation_mask[mask == -1] = False
            mask_bool = torch.cat([mask_bool, compensation_mask])
        else:
            mask_bool = mask_bool[:seg_mask_cut_off]
        return mask_bool
    return mask_transform

mask_transform = get_mask_transform_text(config.num_masks_max)

def transform(batch):
    # Preprocess the image using the ViTImageProcessor
    if processor is not None:
        inputs = processor(batch['passage'], 
                           batch['query_and_answer'], 
                           padding='max_length', 
                           truncation=True, 
                           max_length=512)
        segs = [sent_seg(input_id) for input_id in inputs['input_ids']]
        inputs = {k: torch.tensor(v) for k, v in inputs.items()}
        
        segs_bool = []
        for seg in segs:
            seg_bool = mask_transform(torch.tensor(seg))
            segs_bool.append(seg_bool)
        inputs['segs'] = torch.stack(segs_bool)
        # print("inputs['segs']", inputs['segs'].shape)
        # for k, v in inputs.items():
        #     print(k, v.shape)
        # import pdb; pdb.set_trace()
        return inputs
    else:
        return batch


# train_size, val_size = -1, -1
train_size, val_size = 100, 100

train_dataset = load_dataset('eraser_multi_rc', split='train')
train_dataset = train_dataset.map(transform, batched=True,
                            remove_columns=['passage', 
                                            'query_and_answer',
                                            'evidences'])

val_dataset = load_dataset('eraser_multi_rc', split='validation')
val_dataset = val_dataset.map(transform, batched=True,
                            remove_columns=['passage', 
                                            'query_and_answer',
                                            'evidences'])

if train_size != -1:
    train_dataset = Subset(train_dataset, list(range(train_size)))
if val_size != -1:
    val_dataset = Subset(val_dataset, list(range(val_size)))

# Create a DataLoader to batch and shuffle the data
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

SEP 102 SENT_SEPS [1025, 1010, 1012, 1029, 999, 1025]


In [7]:
from collections import namedtuple

WrappedBackboneOutput = namedtuple("WrappedBackboneOutput", 
                                  ["logits",
                                   "pooler_output"])


class WrappedBackboneModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, inputs=None, **kwargs):
        outputs = self.model(inputs, output_hidden_states=True, **kwargs)
        return WrappedBackboneOutput(outputs.logits, outputs.hidden_states[-1][:,0])

In [8]:
wrapped_backbone_model = WrappedBackboneModel(backbone_model)
wrapped_backbone_model = wrapped_backbone_model.to(device)
class_weights = get_chained_attr(wrapped_backbone_model, config.finetune_layers[0]).weight #.clone().to(device)
projection_layer = wrapped_backbone_model.model.bert.embeddings.word_embeddings

In [9]:
model = SOPTextCls(config, wrapped_backbone_model, class_weights=class_weights, projection_layer=projection_layer)
model = model.to(device)

deep copy class weights


In [10]:
from transformers import get_scheduler

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
num_training_steps = len(train_dataloader) * num_epochs
lr_scheduler = get_scheduler(
            'inverse_sqrt',
            optimizer=optimizer, 
            num_warmup_steps=warmup_steps
        )
criterion = nn.CrossEntropyLoss()

In [11]:
def eval(model, dataloader, criterion, sop=True):
    print('Eval ...')
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar_eval = tqdm(range(len(dataloader)))
        for i, batch in enumerate(dataloader):
            # Now you can use `inputs` and `labels` in your training loop.
            if not isinstance(batch['input_ids'], torch.Tensor):
                inputs = torch.stack(batch['input_ids']).transpose(0, 1).to(device)
                if 'token_type_ids' in batch:
                    token_type_ids = torch.stack(batch['token_type_ids']).transpose(0, 1).to(device)
                else:
                    token_type_ids = None
                attention_mask = torch.stack(batch['attention_mask']).transpose(0, 1).to(device)

                concatenated_rows = [torch.stack(sublist) for sublist in batch['segs']]
                segs = torch.stack(concatenated_rows).permute(2, 0, 1).to(device).float()
                # print('segs', segs.shape)
            else:
                inputs = batch['input_ids'].to(device)
                if 'token_type_ids' in batch:
                    token_type_ids = batch['token_type_ids'].to(device)
                else:
                    token_type_ids = None
                attention_mask = batch['attention_mask'].to(device)
                segs = batch['segs'].to(device).float()
            kwargs = {
                'token_type_ids': token_type_ids,
                'attention_mask': attention_mask,
            }
            labels = batch['label'].to(device)

            if sop:
                logits = model(inputs, segs=segs, kwargs=kwargs)
            else:
                logits = model(inputs, **kwargs).logits
            
            # val loss
            loss = criterion(logits, labels)
            total_loss += loss.item()
            
            # acc
            _, predicted = torch.max(logits.data, 1)
            correct += (predicted == labels).sum().item()
            
            total += labels.size(0)
            
            progress_bar_eval.update(1)
    
    val_acc = correct / total
    val_loss = total_loss / total
    
    model.train()
    
    return {
        'val_acc': val_acc,
        'val_loss': val_loss
    }

In [12]:
backbone_val_results = eval(wrapped_backbone_model, val_dataloader, criterion, sop=False)
backbone_val_acc = backbone_val_results['val_acc']
backbone_val_acc

Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

0.66

In [None]:
import logging

# track = True
track = False

if track:
    import wandb
    wandb.init(project='sop')
    wandb.run.name = os.path.basename(exp_dir)

# Iterate over the data
best_val_acc = 0.0
step = 0
train_log_interval = 100
val_eval_interval = 1000

logging.basicConfig(filename=os.path.join(exp_dir, 'train.log'), level=logging.INFO)

model.train()

progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    running_loss = 0.0
    running_total = 0
    for i, batch in enumerate(train_dataloader):
        # import pdb; pdb.set_trace()
        # inputs, labels = batch
        # inputs, labels = inputs.to(device), labels.to(device)
        if not isinstance(batch['input_ids'], torch.Tensor):
            inputs = torch.stack(batch['input_ids']).transpose(0, 1).to(device)
            if 'token_type_ids' in batch:
                token_type_ids = torch.stack(batch['token_type_ids']).transpose(0, 1).to(device)
            else:
                token_type_ids = None
            attention_mask = torch.stack(batch['attention_mask']).transpose(0, 1).to(device)
            
            concatenated_rows = [torch.stack(sublist) for sublist in batch['segs']]
            segs = torch.stack(concatenated_rows).permute(2, 0, 1).to(device).float()
            # print('segs', segs.shape)
        else:
            inputs = batch['input_ids'].to(device)
            if 'token_type_ids' in batch:
                token_type_ids = batch['token_type_ids'].to(device)
            else:
                token_type_ids = None
            attention_mask = batch['attention_mask'].to(device)
            segs = batch['segs'].to(device).float()
        kwargs = {
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
        }
        labels = batch['label'].to(device)
            
        
        optimizer.zero_grad()
        logits = model(inputs, segs=segs, mask_batch_size=mask_batch_size, kwargs=kwargs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * labels.size(0)
        running_total += labels.size(0)
        
        if i % train_log_interval == train_log_interval - 1 or i == len(train_dataloader) - 1:
            # Print training loss every 100 batches
            curr_lr = float(optimizer.param_groups[0]['lr'])
            log_message = f'Epoch {epoch}, Batch {i + 1}, Loss {running_loss / running_total:.4f}, LR {curr_lr:.8f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'train_loss': running_loss / running_total,
                        'lr': curr_lr,
                        'epoch': epoch,
                        'step': step})
            running_loss = 0.0
            running_total = 0
            
        if i % val_eval_interval == val_eval_interval - 1 or i == len(train_dataloader) - 1:
            val_results = eval(model, val_dataloader, criterion)
            val_acc = val_results['val_acc']
            val_loss = val_results['val_loss']
            log_message = f'Epoch {epoch}, Step {step}, Val acc {val_acc:.4f}, Val loss {val_loss:.4f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'val_acc': val_acc,
                           'val_loss': val_loss,
                        'epoch': epoch,
                        'step': step})
            
            last_dir = os.path.join(exp_dir, 'last')
            best_dir = os.path.join(exp_dir, 'best')
            checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'step': step,
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                }
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                os.makedirs(best_dir, exist_ok=True)
                best_checkpoint_path = os.path.join(best_dir, 'checkpoint.pth')
                torch.save(checkpoint, best_checkpoint_path)
                config_best_checkpoint_path = os.path.join(best_dir, 'config.json')
                config.save_to_json(config_best_checkpoint_path)
                print(f'Best checkpoint saved at {best_checkpoint_path}')
                
                # model.save_pretrained(best_dir)
            # model.save_pretrained(last_dir)
            os.makedirs(last_dir, exist_ok=True)
            last_checkpoint_path = os.path.join(last_dir, 'checkpoint.pth')
            torch.save(checkpoint, last_checkpoint_path)
            config_last_checkpoint_path = os.path.join(last_dir, 'config.json')
            config.save_to_json(config_best_checkpoint_path)
            print(f'Last checkpoint saved at {last_checkpoint_path}')
            
        lr_scheduler.step()
        progress_bar.update(1)
        
        step += 1
        
model.save(exp_dir)

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 0, Batch 50, Loss 1.1663, LR 0.00000049
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0, Step 49, Val acc 0.4700, Val loss 1.0321
Best checkpoint saved at ../../exps/multirc/best/checkpoint.pth
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 1, Batch 50, Loss 1.1369, LR 0.00000036
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Step 99, Val acc 0.4300, Val loss 1.0331
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 2, Batch 50, Loss 1.0119, LR 0.00000029
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 2, Step 149, Val acc 0.4400, Val loss 1.0257
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 3, Batch 50, Loss 1.0897, LR 0.00000025
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 3, Step 199, Val acc 0.4500, Val loss 1.0202
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 4, Batch 50, Loss 1.0293, LR 0.00000022
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 4, Step 249, Val acc 0.4300, Val loss 1.0219
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 5, Batch 50, Loss 0.9356, LR 0.00000020
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 5, Step 299, Val acc 0.4300, Val loss 1.0236
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 6, Batch 50, Loss 0.9700, LR 0.00000019
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 6, Step 349, Val acc 0.4400, Val loss 1.0222
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 7, Batch 50, Loss 1.0250, LR 0.00000018
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 7, Step 399, Val acc 0.4600, Val loss 1.0193
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 8, Batch 50, Loss 1.0006, LR 0.00000017
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 8, Step 449, Val acc 0.4700, Val loss 1.0152
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 9, Batch 50, Loss 0.9656, LR 0.00000016
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 9, Step 499, Val acc 0.4700, Val loss 1.0144
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 10, Batch 50, Loss 1.0919, LR 0.00000015
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 10, Step 549, Val acc 0.4600, Val loss 1.0123
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
Epoch 11, Batch 50, Loss 1.0179, LR 0.00000014
Eval ...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 11, Step 599, Val acc 0.4600, Val loss 1.0118
Last checkpoint saved at ../../exps/multirc/last/checkpoint.pth
