Original code in aigenintern/aigenintern1/23-2/MolCLR/230914_MolCLR_Finetune_MTL.ipynb

Check GPU status with `gpustat -cuFi 1`.

In [1]:
# Setup args for config
import random

import numpy as np
import torch
from dotmap import DotMap

args = DotMap({
    'device':'cuda:1',
    'seed': 0,  # RNG seed for Python / NumPy / PyTorch (applied below)
    'learning_config': {
        'gin_drop_ratio': 0,    # drop ratio for GIN encoder
        'pred_layer_depth': 2,  # depth of the prediction head (author's note: pred layer dim = feat_dim // 2)
        'init_pred_lr': 2e-3, # initial learning rate for the prediction head
        'init_base_lr': 1e-4, # initial learning rate for the base GIN encoder
        'weight_decay': 1e-6,   # weight decay of Adam
        'patience': 15, # early stopping patience
    },
    'scaled' : True, # Solubility, Clearance data will be scaled for training
})

# Seed the global RNGs before any data loading, model construction or training
# so re-runs are comparable. torch.manual_seed seeds the CPU and all CUDA
# devices; some CUDA kernels remain nondeterministic, so GPU runs may still
# differ slightly.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

from src.tdc_constant import TDC
# NOTE: data must be in order. The position of each task in args.data is the
# index of its label/prediction column in the rest of the notebook;
# presumably get_ordered_list returns TDC's canonical ordering — confirm.
args.data = TDC.get_ordered_list([
    TDC.BBB,
#     TDC.Solubility,
#     TDC.CYP3A4,
#     TDC.Clearance
])
args.learning_config.num_tasks = len(args.data)  # one output column per selected task

args.load_config = DotMap({
    'load': False, # If True, do not train. load from ckpts
    'name': 'MolCLR_[BBB, CYP3A4, Clearance, Solubility]-12.06_0025', # Will load ckpts/{name}.pt
})


## Load Data

In [2]:
# Automatically re-import edited Python modules (e.g. the project's src/ code)
# before each cell executes, so source changes apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
def set_model_name():
    """Return a timestamped run name encoding which TDC tasks are trained.

    Every task in ``TDC.allList`` fills one slot: its ``str()`` when it is in
    ``args.data``, otherwise ``'X'``. ``'_sc'`` is appended when ``args.scaled``
    is truthy, followed by ``-MM.DD_HHMM`` from the current local time.
    """
    from datetime import datetime

    slots = ', '.join(str(task) if task in args.data else 'X' for task in TDC.allList)
    scaled_tag = '_sc' if args.scaled else ''
    stamp = datetime.now().strftime("%m.%d_%H%M")
    return f'MolCLR_[{slots}]{scaled_tag}-{stamp}'


if args.load_config.load:
    model_name = args.load_config.name  # reuse the stored run name to find its checkpoint
else:
    model_name = set_model_name()       # fresh, timestamped run name
modelf = 'ckpts/' + model_name + '.pt'  # checkpoint path for this run

model_name, modelf

('MolCLR_[BBB, X, X, X]_sc-12.13_1537',
 'ckpts/MolCLR_[BBB, X, X, X]_sc-12.13_1537.pt')

In [4]:
args.data_config = DotMap({
    'batch_size': 256,  # graphs per mini-batch (value not tuned)
    'num_workers': 0,   # loader worker processes (presumably forwarded to the DataLoader)
})

from src.dataset_mtl import MolTestDatasetWrapper

h_dataset = MolTestDatasetWrapper(
    tdcList=args.data,
    scaled=args.scaled,
    batch_size=args.data_config.batch_size,
    num_workers=args.data_config.num_workers,
)

# train / valid / test splits; each is a torch_geometric.loader.DataLoader
trainloader, validloader, testloader = h_dataset.get_data_loaders()

# number of mini-batches per split (about rows / batch_size)
tuple(len(loader) for loader in (trainloader, validloader, testloader))

(6, 1, 2)

## Create Model

In [5]:
from src.ginet_finetune import GINet_Feat_MTL, load_pre_trained_weights

# Multi-task GIN: mean graph pooling, ReLU prediction head, one output per task.
model = GINet_Feat_MTL(
    pool='mean',
    drop_ratio=args.learning_config.gin_drop_ratio,
    pred_layer_depth=args.learning_config.pred_layer_depth,
    num_tasks=args.learning_config.num_tasks,
    pred_act='relu',
)
model = model.to(args.device)

# Initialise from the MolCLR pretrained GIN checkpoint (path is relative to the
# notebook's working directory).
model = load_pre_trained_weights(
    model,
    args.device,
    '../aigenintern1/23-2/MolCLR/pretrained_weights/pretrained_gin_model.pth',
)

KeyboardInterrupt: 

In [None]:
# set different learning rates for prediction head and base
import torch

# One pass over the parameters. Names containing 'pred_head' belong to the
# prediction head; everything else is the base GIN encoder. Head parameter
# names are printed with their requires_grad flag as a learnability check.
layer_list = []   # names of prediction-head parameters
params = []       # prediction-head parameters -> optimizer-wide lr (init_pred_lr)
base_params = []  # base encoder parameters    -> init_base_lr
for name, param in model.named_parameters():
    if 'pred_head' not in name:
        base_params.append(param)
        continue
    print(name, param.requires_grad)
    layer_list.append(name)
    params.append(param)

# The base group overrides lr. The head group inherits the default lr below.
optimizer = torch.optim.Adam(
    [
        {'params': base_params, 'lr': args.learning_config.init_base_lr},
        {'params': params},
    ],
    lr=args.learning_config.init_pred_lr,
    weight_decay=args.learning_config.weight_decay,
)

## Define loss, training and evaluation functions

In [None]:
from torch import nn

# One loss module per task, in args.data order (criterions[i] scores output
# column i): MSE for regression, BCE on raw logits for classification.
criterions = [
    nn.MSELoss() if tdc.isRegression() else nn.BCEWithLogitsLoss()
    for tdc in args.data
]

In [None]:
import math
def mtl_loss(pred, label, criterions):
    """Average of the per-task losses for one multi-task batch.

    Column ``i`` of ``pred`` and ``label`` belongs to task ``i`` and is scored
    with ``criterions[i]``. Missing labels are NaN and are masked out. Each
    task's loss is therefore the criterion's mean over that task's labelled
    rows only, which keeps tasks with many labels from dominating tasks with few.

    A task with no labelled rows in the batch contributes an exact zero that
    stays attached to the autograd graph (a sum over an empty selection). As a
    result ``backward()`` still works when no task has labels. The returned
    loss is the mean over *all* tasks, including those zero entries.

    A NaN produced by the model or by a criterion is propagated, not replaced
    by zero, so numerical divergence stays visible in the logged losses.

    Args:
        pred: predictions, indexed as ``pred[:, task]``.
        label: targets with the same column layout; NaN marks a missing label.
        criterions: one loss module per task column, in column order.

    Returns:
        0-d tensor: mean of the per-task losses.
    """
    task_losses = []
    for i, criterion in enumerate(criterions):
        label_task = label[:, i].squeeze()
        pred_task = pred[:, i].squeeze()
        mask = ~torch.isnan(label_task)
        if mask.any():
            # Criterion means over the labelled rows of this task only.
            task_losses.append(criterion(pred_task[mask], label_task[mask]))
        else:
            # No labels for this task in the batch: graph-connected 0 with zero grads.
            task_losses.append(pred_task[mask].sum())

    # loss = mean over tasks of the per-task (batch-averaged) losses
    return torch.mean(torch.stack(task_losses), dim=0)

In [None]:
def train(model, trainloader, args, optimizer=None, criterions=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network being fine-tuned; switched to train mode here.
        trainloader: iterable of batches exposing ``.to(device)`` and ``.y``;
            must support ``len()``.
        args: config; only ``args.device`` is read.
        optimizer: optimizer over ``model``'s parameters. Defaults to the
            notebook-level ``optimizer``, looked up at call time.
        criterions: per-task loss modules passed to ``mtl_loss``. Defaults to
            the notebook-level ``criterions``, looked up at call time.

    Returns:
        float: sum of batch losses divided by the number of batches (each batch
        weighted equally, regardless of its size).
    """
    # Defaults such as ``optimizer=optimizer`` would be evaluated once, when
    # this cell runs. After re-running the model/optimizer cells, train() would
    # then step a stale optimizer tied to the old model's parameters, and the
    # new model would never update. Resolving the globals per call avoids that.
    if optimizer is None:
        optimizer = globals()['optimizer']
    if criterions is None:
        criterions = globals()['criterions']

    model.train()  # set to train mode
    train_loss = 0
    for batch in trainloader:
        batch = batch.to(args.device)
        label = batch.y

        optimizer.zero_grad()
        pred = model(batch)
        loss = mtl_loss(pred, label, criterions)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float so each batch's autograd graph can be freed.
        train_loss += loss.item()
    return train_loss / len(trainloader)

In [None]:
def eval(model, loader, args, criterions=None):
    """Return the mean per-batch multi-task loss of ``model`` over ``loader``.

    Runs in eval mode under ``torch.no_grad()``, so no weights are updated.
    NOTE: this shadows Python's builtin ``eval`` in the notebook namespace. The
    name is kept because later cells call it.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``; must
            support ``len()``.
        args: config; only ``args.device`` is read.
        criterions: per-task loss modules passed to ``mtl_loss``. Defaults to
            the notebook-level ``criterions``, looked up at call time instead
            of being frozen when this cell was executed.

    Returns:
        float: sum of batch losses divided by the number of batches.
    """
    if criterions is None:
        criterions = globals()['criterions']

    model.eval()  # Set to eval mode
    eval_loss = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(args.device)
            label = batch.y
            pred = model(batch)
            eval_loss += mtl_loss(pred, label, criterions).item()
    return eval_loss / len(loader)

## Train & Validate

In [None]:
from datetime import datetime
# Train until early stopping. Skipped when reusing an existing checkpoint.
# Truthiness matches the model-name cell, which also uses `if args.load_config.load`.
if not args.load_config.load:
    from src.EarlyStopper import EarlyStopper

    epoch = 0
    print_every_n_epoch = 5  # echo the epoch log line to the notebook every N epochs

    # Project helper. It receives the checkpoint path `modelf`, which is
    # reloaded with torch.load afterwards, so presumably it saves the best
    # model there — confirm in src/EarlyStopper.py.
    early_stopper = EarlyStopper(patience=args.learning_config.patience, printfunc=print,
                                 verbose=False, path=modelf)

    # Per-run log file, opened in append mode.
    # NOTE(review): ':' is not a valid filename character on Windows.
    with open(f'Log: {model_name}.txt', 'a') as fp:
        fp.write('Start Training\n')
        fp.write(f'{model_name}\n')
        fp.write(f'{args.device}\n')
        fp.flush()
        while True:
            epoch += 1
            train_loss = train(model, trainloader, args)
            valid_loss = eval(model, validloader, args)
            log_line = (f'[Epoch{epoch}] train_loss: {train_loss:.4f}, '
                        f'valid_loss: {valid_loss:.4f}. {datetime.now().strftime("%H:%M:%S")}')
            fp.write(log_line + '\n')
            fp.flush()  # keep the log file current while training runs
            if epoch % print_every_n_epoch == 0:
                print(log_line)

            early_stopper(valid_loss, model)
            if early_stopper.early_stop:
                fp.write('early stopping\n')
                fp.flush()
                print(f'early stopping at epoch {epoch}')
                break
else:
    print('Skip Training')

In [None]:
# Restore weights from `modelf` onto the configured device. This is the path
# handed to EarlyStopper during training, or ckpts/{name}.pt when loading a
# previous run.
model.load_state_dict(
    torch.load(modelf, map_location=args.device)
)
print('loaded "' + modelf + '"')

In [None]:
# Mean per-batch multi-task loss on the held-out test split.
test_loss = eval(model, testloader, args)
print('Final test loss: {:.4f}'.format(test_loss))

## Additional diagnostics

In [None]:
# Learning Information: which tasks this run predicts and how each is scored.
print(
    "Learning Information",
    f'Predicting on {args.data}',
    f'Multi-task: predicting {args.learning_config.num_tasks} task',
    f'Criterion: {criterions}',
    sep='\n',
)

In [None]:
# Example test: per-task metrics on the test split.
# Inference runs once. Each task then reads only its own output/label column
# (the same [:, i] layout mtl_loss uses), so tasks are never mixed and missing
# (NaN) labels are excluded from accuracy.
model.eval() # set to eval mode

preds = []
ys = []
first_smiles = None
with torch.no_grad():
    for batch in testloader:
        batch = batch.to(args.device)
        if first_smiles is None:
            first_smiles = batch.smiles[0]
        preds.append(model(batch))
        ys.append(batch.y)

preds = torch.cat(preds, dim=0)  # rows = test molecules, columns = tasks
ys = torch.cat(ys, dim=0)

for i, tdc in enumerate(args.data):
    task_pred = preds[:, i]
    task_y = ys[:, i]
    labelled = ~torch.isnan(task_y)
    n_labelled = int(labelled.sum())
    if not tdc.isRegression():
        if n_labelled == 0:
            print(f'{tdc}: no labelled test samples')
            continue
        # Model outputs logits; a prediction is correct when the sigmoid
        # probability lies within 0.5 of the 0/1 label.
        probs = torch.sigmoid(task_pred[labelled])
        correct = (torch.abs(probs - task_y[labelled]) < 0.5).float().sum()
        accuracy = 100 * correct / n_labelled
        print(f'{tdc} accuracy: {accuracy:.1f} (n={n_labelled})')
    else:
        # First test molecule. Values may be in scaled units when args.scaled
        # is True — presumably the dataset wrapper owns the scaler; confirm.
        print(f'{tdc} sample')
        print(first_smiles)
        print(f'pred: {task_pred[0]} label: {task_y[0]}')

In [None]:
# Draw sample data: inspect the second molecular graph of the first training batch.
import networkx as nx
from torch_geometric.utils import to_networkx

databatch = next(iter(trainloader), None)
if databatch is not None:
    data = databatch[1]
    print(data)
    g = to_networkx(data)
    nx.draw(g, with_labels=True)
    # first node-feature column; per the author's note 0: H, 5: C
    print(data.x[:, 0])
    print(data.smiles)