In [1]:
import json
import os
import argparse

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm

from pathlib import Path
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append('../../lib/exlib/src')
from exlib.modules.sop import SOPImageCls, SOPConfig, get_chained_attr, get_inverse_sqrt_with_separate_heads_schedule_with_warmup
from exlib.datasets.cosmogrid import CosmogridDataset, CNNModel

In [2]:
# Use the default CUDA device when one is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Global seed for every RNG this notebook touches; set SEED = -1 to skip seeding.
SEED = 42
if SEED != -1:
    # Python and NumPy global RNGs
    random.seed(SEED)
    np.random.seed(SEED)
    # Torch CPU generator, the current CUDA device, and all CUDA devices
    for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_rng(SEED)

In [3]:
# ---- Model checkpoint --------------------------------------------------------
# Pretrained CNN backbone weights (loaded as a state_dict below).
backbone_model_name = '../../data/cosmogrid/CNN_mass_maps.pth'

# ---- Data ----------------------------------------------------------------------
TRAIN_DATA_DIR = '../../data/cosmogrid'
VAL_DATA_DIR = '../../data/cosmogrid'
# Precomputed segmentation masks handed to the dataset as `mask_path`.
mask_path = '../../data/processed/cosmogrid/masks/X_maps_Cosmogrid_100k_watershed_diagonal.npy'

# ---- Training hyper-parameters ---------------------------------------------------
batch_size = 16
lr = 5e-4
num_epochs = 20
warmup_steps = 2_000
mask_batch_size = 64
# Alternative run: num_heads = 1 together with exp_dir = '../../exps/cosmogrid'.
num_heads = 4

# ---- Experiment output -----------------------------------------------------------
exp_dir = '../../exps/cosmogrid_4h'
os.makedirs(exp_dir, exist_ok=True)

In [4]:
config = SOPConfig(
    json_file='../configs/cosmogrid.json',
    num_heads=num_heads,
)

# Pretrained CNN backbone, constructed from the config's label count.
backbone_model = CNNModel(config.num_labels)
# Load the weights onto the CPU first so a checkpoint saved from a GPU also opens
# on CPU-only machines; the backbone is moved to `device` in a later cell.
# NOTE(review): torch.load unpickles the file. If the checkpoint could come from an
# untrusted source, consider weights_only=True (requires torch >= 1.13).
state_dict = torch.load(backbone_model_name, map_location='cpu')
backbone_model.load_state_dict(state_dict=state_dict)
# No input processor is used with this backbone.
processor = None

In [5]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader


# Load the train/val splits. A size of -1 loads the whole split (the captured
# output reports 80000 train and 10000 val images with this setting).
train_size, val_size = -1, -1
# train_size = 100
# val_size = 100

# Arguments shared by both splits; only root_dir / split / data_size differ.
_cosmogrid_kwargs = dict(
    inputs_filename='X_maps_Cosmogrid_100k.npy',
    labels_filename='y_maps_Cosmogrid_100k.npy',
    mask_path=mask_path,
    num_masks_max=config.num_masks_max,
)
train_dataset = CosmogridDataset(root_dir=TRAIN_DATA_DIR, split='train', data_size=train_size,
                                 **_cosmogrid_kwargs)
# Validation data is read from VAL_DATA_DIR. The two directories are configured
# separately even though they currently point at the same place.
val_dataset = CosmogridDataset(root_dir=VAL_DATA_DIR, split='val', data_size=val_size,
                               **_cosmogrid_kwargs)

# Use subset for testing purpose
# num_data = 100
# train_dataset = Subset(train_dataset, range(num_data))
# val_dataset = Subset(val_dataset, range(num_data))

# Shuffle only the training batches; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# samples used for training: 80000
# samples used for validation: 10000
# samples used for testing: 10000
# total samples: 100000
x shape (80000, 66, 66) (10000, 66, 66) (10000, 66, 66)
y shape (80000, 6) (10000, 6) (10000, 6)
masks shape (80000, 66, 66) (10000, 66, 66) (10000, 66, 66)
-- ALL --
max 0.7257571922558966
min -0.034935039865926346
-- SPLIT train --
max 0.7257571922558966
min -0.034935039865926346
Finished loading 80000 train images ... 
# samples used for training: 80000
# samples used for validation: 10000
# samples used for testing: 10000
# total samples: 100000
x shape (80000, 66, 66) (10000, 66, 66) (10000, 66, 66)
y shape (80000, 6) (10000, 6) (10000, 6)
masks shape (80000, 66, 66) (10000, 66, 66) (10000, 66, 66)
-- ALL --
max 0.6323861033355062
min -0.031224769235240986
-- SPLIT val --
max 0.6323861033355062
min -0.031224769235240986
Finished loading 10000 val images ... 


In [6]:
# Place the pretrained backbone on the selected device before it is wrapped.
backbone_model = backbone_model.to(device=device)

In [7]:
# Build the SOP image model around the backbone and put it on the same device.
model = SOPImageCls(config, backbone_model).to(device)

deep copy class weights


In [8]:
from transformers import get_scheduler

# Optimize only parameters that are left trainable (requires_grad=True).
optimizer = optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=lr)
num_training_steps = len(train_dataloader) * num_epochs
# Number of steps per head stage: the training loop passes
# `step // train_rep_step_size` to the model as `epoch`. Clamp to >= 1 so tiny
# debug runs (num_training_steps < num_heads) don't divide by zero in that loop.
train_rep_step_size = max(1, num_training_steps // config.num_heads)
lr_scheduler = get_inverse_sqrt_with_separate_heads_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_steps_per_epoch=train_rep_step_size,
    num_heads=config.num_heads,
)
# Regression objective on the continuous labels.
criterion = nn.MSELoss()

In [9]:
def eval(model, dataloader, criterion, postprocess=lambda x:x):
    """Compute the sample-weighted mean loss of `model` over `dataloader`.

    The name shadows Python's builtin ``eval`` inside this notebook; it is kept
    because other cells call it.

    Args:
        model: module called as ``model(inputs)`` under ``torch.no_grad()``.
        dataloader: yields 4-tuples ``(inputs, labels, masks, _)``. Inputs and labels
            are cast to float and moved to the global ``device``.
        criterion: loss called as ``criterion(predictions, labels)``. It is assumed to
            return a batch mean, which is why it is re-weighted by the batch size.
        postprocess: maps the raw model output to the prediction tensor, e.g.
            ``lambda out: out.logits`` for models returning an output object.

    Returns:
        dict with key ``'val_loss'``: the loss averaged over all samples.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    print('Eval ...')
    # Restore whatever mode the caller had the model in, even if evaluation fails.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels, _masks, _ in tqdm(dataloader, total=len(dataloader)):
                inputs = inputs.to(device, dtype=torch.float)
                labels = labels.to(device, dtype=torch.float)

                # NOTE(review): the segmentation masks are not passed to the model here,
                # while the training loop calls model(inputs, segs=masks, ...). Confirm
                # this is the intended evaluation setting for SOP models.
                logits = postprocess(model(inputs))

                # val loss
                loss = criterion(logits, labels)
                n_samples = labels.size(0)
                total_loss += loss.item() * n_samples
                total += n_samples
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError('eval() received a dataloader that yielded no samples')

    return {
        'val_loss': total_loss / total
    }

In [10]:
# Baseline: validation loss of the pretrained backbone on its own. Its forward
# output exposes `.logits`, so unwrap that before computing the loss.
backbone_val_results = eval(
    backbone_model,
    val_dataloader,
    criterion,
    postprocess=lambda output: output.logits,
)
backbone_val_loss = backbone_val_results['val_loss']
backbone_val_loss

Eval ...


  0%|          | 0/625 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


0.008118031978979708

In [None]:
import logging

# ---- Run options -------------------------------------------------------------
track = True  # log metrics to Weights & Biases
# track = False
early_stop = False  # stop as soon as val loss <= the backbone-only val loss
early_stop_met = False

if track:
    import wandb
    wandb.init(project='sop')
    wandb.run.name = os.path.basename(exp_dir)

# Iterate over the data
best_val_loss = np.inf
step = 0
train_log_interval = 100  # batches between training-loss log lines
val_eval_interval = 1000  # batches between validation passes (plus each epoch's last batch)

logging.basicConfig(filename=os.path.join(exp_dir, 'train.log'), level=logging.INFO)

last_dir = os.path.join(exp_dir, 'last')
best_dir = os.path.join(exp_dir, 'best')


def _save_checkpoint(checkpoint, ckpt_dir, sop_config):
    """Save `checkpoint` and `sop_config` side by side in `ckpt_dir`.

    Creates the directory if needed, writes ``checkpoint.pth`` with ``torch.save``
    and ``config.json`` with ``sop_config.save_to_json``, and returns the path of
    the checkpoint file.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    checkpoint_path = os.path.join(ckpt_dir, 'checkpoint.pth')
    torch.save(checkpoint, checkpoint_path)
    # Each config.json must live in the same directory as the weights it describes.
    sop_config.save_to_json(os.path.join(ckpt_dir, 'config.json'))
    return checkpoint_path


model.train()

num_batches = len(train_dataloader)
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    running_loss = 0.0
    running_total = 0
    for i, batch in enumerate(train_dataloader):
        inputs, labels, masks, _ = batch
        inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device, dtype=torch.float)
        masks = masks.to(device)

        optimizer.zero_grad()
        # Stage index (step // train_rep_step_size), passed to the model as `epoch`.
        train_rep_step = step // train_rep_step_size
        logits = model(inputs, segs=masks, epoch=train_rep_step, mask_batch_size=mask_batch_size)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion is assumed to return a batch mean, so re-weight by batch size.
        running_loss += loss.item() * labels.size(0)
        running_total += labels.size(0)

        is_last_batch = i == num_batches - 1

        if i % train_log_interval == train_log_interval - 1 or is_last_batch:
            # Print training loss every `train_log_interval` batches
            curr_lr = float(optimizer.param_groups[0]['lr'])
            train_loss = running_loss / running_total
            log_message = f'Epoch {epoch}, Batch {i + 1}, Loss {train_loss:.4f}, LR {curr_lr:.8f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'train_loss': train_loss,
                           'lr': curr_lr,
                           'epoch': epoch,
                           'step': step})
            running_loss = 0.0
            running_total = 0

        if i % val_eval_interval == val_eval_interval - 1 or is_last_batch:
            val_results = eval(model, val_dataloader, criterion)
            val_loss = val_results['val_loss']
            log_message = f'Epoch {epoch}, Step {step}, Val loss {val_loss:.4f}'
            print(log_message)
            logging.info(log_message)
            if track:
                wandb.log({'val_loss': val_loss,
                           'epoch': epoch,
                           'step': step})

            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'step': step,
                'val_loss': val_loss,
            }
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_checkpoint_path = _save_checkpoint(checkpoint, best_dir, config)
                print(f'Best checkpoint saved at {best_checkpoint_path}')

            last_checkpoint_path = _save_checkpoint(checkpoint, last_dir, config)
            print(f'Last checkpoint saved at {last_checkpoint_path}')

            if early_stop and val_loss <= backbone_val_loss:
                early_stop_met = True
                break

        lr_scheduler.step()
        progress_bar.update(1)

        step += 1

    if early_stop_met:
        break

progress_bar.close()
model.save(exp_dir)

[34m[1mwandb[0m: Currently logged in as: [33mfallcat[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 0, Batch 100, Loss 0.1137, LR 0.00002475
Epoch 0, Batch 200, Loss 0.0824, LR 0.00004975
Epoch 0, Batch 300, Loss 0.0327, LR 0.00007475
Epoch 0, Batch 400, Loss 0.0250, LR 0.00009975
Epoch 0, Batch 500, Loss 0.0203, LR 0.00012475
Epoch 0, Batch 600, Loss 0.0204, LR 0.00014975
Epoch 0, Batch 700, Loss 0.0194, LR 0.00017475
Epoch 0, Batch 800, Loss 0.0194, LR 0.00019975
Epoch 0, Batch 900, Loss 0.0184, LR 0.00022475
Epoch 0, Batch 1000, Loss 0.0178, LR 0.00024975
Eval ...


  0%|          | 0/625 [00:00<?, ?it/s]