In [1]:
import torch
import wandb
from tqdm.auto import tqdm
from torchmetrics import AUROC, ROC, Accuracy
from dataset import ImageDatasetFromParquet
import torch_geometric
import torchvision.transforms as T
import torchvision
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still executes (slowly) instead of failing when the model is moved.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Mini-batch sizes used by the train / validation / test data loaders.
TRAIN_BATCH_SIZE = 64
VAL_BATCH_SIZE = 64
TEST_BATCH_SIZE = 64

# Number of full passes over the training split.
NUM_EPOCHS = 5

In [3]:
# Transform list handed to every ImageDatasetFromParquet instance below:
# two random flips followed by per-channel normalisation.
required_transform = [
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]

In [4]:
# Parquet shards, one per simulation run.
run_0_path, run_1_path, run_2_path = (
    "/scratch/gsoc/parquet_ds/QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet",
    "/scratch/gsoc/parquet_ds/QCDToGGQQ_IMGjet_RH1all_jet0_run1_n47540.test.snappy.parquet",
    "/scratch/gsoc/parquet_ds/QCDToGGQQ_IMGjet_RH1all_jet0_run2_n55494.test.snappy.parquet",
)

# One dataset per shard, all sharing the same transform list and all asked to
# return the regression fields (`return_regress=True`).
run_0_ds, run_1_ds, run_2_ds = (
    ImageDatasetFromParquet(shard_path, transforms=required_transform, return_regress=True)
    for shard_path in (run_0_path, run_1_path, run_2_path)
)

# Single dataset view over the three shards, in run order.
combined_dset = torch.utils.data.ConcatDataset([run_0_ds, run_1_ds, run_2_ds])

In [5]:
# Fractions of the combined dataset reserved for testing and validation.
TEST_SIZE = 0.2
VAL_SIZE = 0.15

# Truncate the fractional sizes to whole samples; the training split takes
# whatever is left so the three sizes always add up to the full dataset.
test_size, val_size = (
    int(len(combined_dset) * fraction) for fraction in (TEST_SIZE, VAL_SIZE)
)
train_size = len(combined_dset) - (val_size + test_size)

# A dedicated, fixed-seed generator keeps the split identical across runs.
train_dset, val_dset, test_dset = torch.utils.data.random_split(
    dataset=combined_dset,
    lengths=[train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [6]:
# Apparent intent: give the test split a deterministic pipeline (normalisation
# only, without the random flips used for training).
# NOTE(review): `test_dset` is one of the objects returned by
# `torch.utils.data.random_split` above, and the underlying datasets were
# configured through the `transforms=` constructor argument. Nothing in this
# notebook reads a `required_transforms` attribute, so this assignment looks
# like a no-op and the test (and validation) samples presumably still go
# through the random flips — TODO confirm against ImageDatasetFromParquet and
# apply the evaluation transforms where the dataset actually reads them.
test_dset.required_transforms = [T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]

In [7]:
# Only the training loader shuffles; validation and test keep a fixed order.
# All three loaders pin host memory: the train/val/test loops copy batches with
# `non_blocking=True`, which can only overlap with compute when the source
# tensors live in pinned memory (the test loader previously missed this flag).
train_loader = torch.utils.data.DataLoader(
    train_dset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
    pin_memory=True,
    num_workers=16,
)
val_loader = torch.utils.data.DataLoader(
    val_dset,
    shuffle=False,
    batch_size=VAL_BATCH_SIZE,
    pin_memory=True,
    num_workers=16,
)
test_loader = torch.utils.data.DataLoader(
    test_dset,
    shuffle=False,
    batch_size=TEST_BATCH_SIZE,
    pin_memory=True,
    num_workers=16,
)

In [8]:
class AverageMeter(object):
    """Tracks the latest value and the running weighted mean of a metric.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight (e.g. number of samples) seen since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n``.

        Args:
            val: The new measurement (for example a batch-mean loss).
            n: Weight of the measurement, typically the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update arrives with n=0) would
        # otherwise raise ZeroDivisionError; report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0

In [9]:
class RegressModel(torch.nn.Module):
    """Image backbone followed by an MLP head that also receives ``pt``.

    The backbone's ``fc`` layer is swapped for an identity, so calling the
    backbone returns whatever used to be fed into ``fc``. Those features are
    concatenated with one extra column built from ``pt`` and pushed through a
    three-layer MLP that ends in a single output unit.

    Note that the ``model`` instance passed in is modified in place (its
    ``fc`` attribute is replaced).
    """

    def __init__(self, model):
        """
        Args:
            model: A module exposing an ``fc`` layer with an ``in_features``
                attribute.
        """
        super().__init__()
        nn = torch.nn

        feat_dim = model.fc.in_features
        model.fc = nn.Identity()
        self.model = model

        half, quarter = feat_dim // 2, feat_dim // 4
        # The first layer takes `feat_dim + 1` inputs: the backbone features
        # plus the single `pt` column appended in `forward`.
        head_layers = [
            nn.Linear(feat_dim + 1, half, bias=True),
            nn.BatchNorm1d(half),
            nn.SiLU(),
            nn.Dropout(),
            nn.Linear(half, quarter, bias=True),
            nn.BatchNorm1d(quarter),
            nn.SiLU(),
            nn.Dropout(),
            nn.Linear(quarter, 1, bias=True),
        ]
        self.out_lin = nn.Sequential(*head_layers)

    def forward(self, X, pt):
        """Return the head's output for images ``X`` and side input ``pt``.

        ``pt`` gets a trailing axis added and is concatenated to the backbone
        features along dim 1 (assumes ``pt`` has one entry per sample, i.e. is
        1-D — TODO confirm against the dataset).
        """
        features = self.model(X)
        pt_column = pt.unsqueeze(-1)
        return self.out_lin(torch.cat((features, pt_column), dim=1))

In [10]:
def get_model(device):
    """Build the regression network on top of a pretrained ResNet-50.

    Args:
        device: Device (or device string) the finished model is moved to.

    Returns:
        A ``RegressModel`` wrapping ``torchvision.models.resnet50`` loaded with
        pretrained weights, already placed on ``device``.
    """
    backbone = torchvision.models.resnet50(pretrained=True)
    return RegressModel(model=backbone).to(device)

In [11]:
def get_optimizer(model, lr):
    """Create an Adam optimizer covering every parameter of ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        lr: Learning rate passed to Adam.

    Returns:
        The configured ``torch.optim.Adam`` instance.
    """
    all_params = model.parameters()
    optimizer = torch.optim.Adam(all_params, lr=lr)
    return optimizer

In [12]:
def _forward_loss(model, criterion, batch, device):
    """Move one batch to ``device``, run the model and return ``(out, loss)``."""
    X = batch['X_jets'].float().to(device, non_blocking=True)
    pt = batch['pt'].float().to(device, non_blocking=True)
    m0 = batch['m0'].float().to(device, non_blocking=True)

    out = model(X, pt)
    # The target gets a trailing axis before being compared with `out`
    # (presumably so it matches a (batch, 1) prediction — confirm for other models).
    return out, criterion(out, m0.unsqueeze(-1))


def train(num_epochs, model, criterion, optimizer, train_loader, val_loader, device):
    """Train ``model`` and return the snapshot with the lowest validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. Per-batch losses are logged to wandb
    together with a sample-based step counter. After every epoch the
    sample-weighted mean validation loss is compared with the best one so far
    and, when lower, a CPU copy of the model is kept.

    Args:
        num_epochs: Number of epochs to run.
        model: Module called as ``model(X, pt)``; it is updated in place.
        criterion: Loss called as ``criterion(prediction, target)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Loader yielding dict batches with the keys ``'X_jets'``,
            ``'pt'`` and ``'m0'``.
        val_loader: Loader with the same batch layout, used for model selection.
        device: Device batches (and the returned model) are moved to.

    Returns:
        A copy of the model, on ``device``, taken at the epoch with the lowest
        mean validation loss. If no epoch improves on ``inf`` (for example
        ``num_epochs == 0`` or NaN losses) this is a copy of the model as it was
        on entry.
    """
    # Derive the logging offsets from the loaders themselves rather than from
    # notebook-level globals, so the step counters stay correct for any loader.
    train_batch_size = train_loader.batch_size or 1
    train_samples = len(train_loader.dataset)
    val_batch_size = val_loader.batch_size or 1
    val_samples = len(val_loader.dataset)

    # Snapshots are copied to the CPU synchronously: a non-blocking
    # device-to-host copy gives no guarantee the data has arrived when read.
    best_model = copy.deepcopy(model).to('cpu')
    best_val_loss = float('inf')
    val_loss_avg_meter = AverageMeter()

    for epoch in range(num_epochs):
        model.train()
        tqdm_iter = tqdm(train_loader, total=len(train_loader))
        tqdm_iter.set_description(f"Epoch {epoch}")

        for it, batch in enumerate(tqdm_iter):
            optimizer.zero_grad()

            _, loss = _forward_loss(model, criterion, batch, device)
            loss.backward()
            optimizer.step()

            # Read the scalar once; every `.item()` forces a device sync.
            loss_value = loss.item()
            tqdm_iter.set_postfix(loss=loss_value)
            wandb.log({
                "train_mse_loss": loss_value,
                "train_step": (it * train_batch_size) + epoch * train_samples
            })

        model.eval()
        val_tqdm_iter = tqdm(val_loader, total=len(val_loader))
        val_tqdm_iter.set_description(f"Validation Epoch {epoch}")
        val_loss_avg_meter.reset()

        with torch.no_grad():
            for it, batch in enumerate(val_tqdm_iter):
                out, loss = _forward_loss(model, criterion, batch, device)

                loss_value = loss.item()
                val_tqdm_iter.set_postfix(loss=loss_value)
                wandb.log({
                    "val_mse_loss": loss_value,
                    "val_step": (it * val_batch_size) + epoch * val_samples
                })
                # Weight by batch size so a short final batch does not skew the mean.
                val_loss_avg_meter.update(loss_value, out.size(0))

        if val_loss_avg_meter.avg < best_val_loss:
            best_val_loss = val_loss_avg_meter.avg
            best_model = copy.deepcopy(model).to('cpu')

    return best_model.to(device)

In [13]:
def test(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and return the mean loss.

    Args:
        model: Module called as ``model(X, pt)``; it is switched to eval mode.
        test_loader: Loader yielding dict batches with the keys ``'X_jets'``,
            ``'pt'`` and ``'m0'``.
        criterion: Loss called as ``criterion(prediction, target)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The per-batch losses averaged with each batch weighted by its size
        (0 if the loader yields no batches).
    """
    model.eval()
    test_loss_avg_meter = AverageMeter()
    tqdm_iter = tqdm(test_loader, total=len(test_loader))

    # No gradients are needed anywhere in evaluation, so disable them once.
    with torch.no_grad():
        for batch in tqdm_iter:
            X = batch['X_jets'].float().to(device, non_blocking=True)
            pt = batch['pt'].float().to(device, non_blocking=True)
            m0 = batch['m0'].float().to(device, non_blocking=True)

            out = model(X, pt)

            # Read the scalar once; every `.item()` forces a device sync.
            loss_value = criterion(out, m0.unsqueeze(-1)).item()

            tqdm_iter.set_postfix(loss=loss_value)
            test_loss_avg_meter.update(loss_value, out.size(0))

    return test_loss_avg_meter.avg

In [14]:
def main(run_name):
    """Run one full experiment: build, train, evaluate and log to wandb.

    Uses the notebook-level ``DEVICE``, ``NUM_EPOCHS``, ``train_loader``,
    ``val_loader`` and ``test_loader``. The model is optimised with Adam
    (lr=3e-4) against an MSE loss, and the test loss is printed.

    Args:
        run_name: Name given to the wandb run (project ``gsoc-submission``).

    Returns:
        The model returned by ``train`` (its best-validation snapshot).

    Raises:
        Whatever model construction, training or evaluation raises; the wandb
        run is closed with a non-zero exit code before the error propagates.
    """
    wandb.init(name=run_name, project='gsoc-submission')

    try:
        model = get_model(DEVICE)

        opt = get_optimizer(model, lr=3e-4)
        criterion = torch.nn.MSELoss()

        model = train(NUM_EPOCHS, model, criterion, opt, train_loader, val_loader, DEVICE)
        test_loss = test(model, test_loader, criterion, DEVICE)
        print(f"Model on Test dataset: Loss: {test_loss}")
    except BaseException:
        # Includes KeyboardInterrupt: never leave the run dangling in the
        # kernel, and mark it as failed rather than finished.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

    return model

In [15]:
# Launch the experiment; the returned model is saved in the last cell.
model = main(run_name='task_2_regress_resnet')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjai-bardhan[0m (use `wandb login --relogin` to force relogin)


Epoch 0: 100%|██████████| 1415/1415 [07:06<00:00,  3.32it/s, loss=11]  
Validation Epoch 0: 100%|██████████| 327/327 [01:46<00:00,  3.08it/s, loss=11.7]
Epoch 1: 100%|██████████| 1415/1415 [07:05<00:00,  3.33it/s, loss=11.6]
Validation Epoch 1: 100%|██████████| 327/327 [01:47<00:00,  3.03it/s, loss=6.95]
Epoch 2: 100%|██████████| 1415/1415 [06:57<00:00,  3.39it/s, loss=10.7]
Validation Epoch 2: 100%|██████████| 327/327 [01:42<00:00,  3.18it/s, loss=5.04]
Epoch 3: 100%|██████████| 1415/1415 [06:52<00:00,  3.43it/s, loss=15.6]
Validation Epoch 3: 100%|██████████| 327/327 [01:38<00:00,  3.31it/s, loss=9.03]
Epoch 4: 100%|██████████| 1415/1415 [06:59<00:00,  3.37it/s, loss=9.12]
Validation Epoch 4: 100%|██████████| 327/327 [01:44<00:00,  3.13it/s, loss=6.45]
100%|██████████| 436/436 [02:17<00:00,  3.18it/s, loss=6.95]


Model on Test dataset: Loss: 5.6690891087771895



0,1
train_mse_loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_mse_loss,█▆▄█▆▅▇▆▃▃▄▁▆▄▂▄▃▃▁▄▄▂▂▃▆▄▅▃▇▄▅▅▃▁▂▂▂▁▁▂
val_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
train_mse_loss,9.11817
train_step,452696.0
val_mse_loss,6.44981
val_step,104444.0


In [16]:
# Persist only the weights (state dict), not the pickled module. The tensors
# are saved from whatever device `model` currently lives on, so pass
# `map_location` to `torch.load` when restoring on a different device.
torch.save(obj=model.state_dict(), f="task_2_regress_model.pt")