In [7]:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
from src.benchmark.toy_dataset import Task_redundant ,Task_synergy,Task_combination
from pytorch_lightning.trainer import seed_everything
from src.baseline.baseline import O_Estimator
# Fix the global random seed so repeated runs of the notebook are comparable
# (Lightning confirms this with the "Seed set to 42" log line).
seed_everything(42)

def get_samples(test_loader, mod_list, device):
    """Gather every batch of ``test_loader`` into one tensor per modality.

    Args:
        test_loader: Iterable of batches; each batch is indexed by modality
            name and yields a tensor (``batch[mod]``).
        mod_list: Iterable of modality names to collect. It is iterated more
            than once, so it must be re-iterable (list, tuple, dict, ...).
        device: Device (or any argument accepted by ``Tensor.to``) on which
            the collected tensors are placed.

    Returns:
        dict mapping each modality name to the concatenation, along dim 0, of
        that modality's batches in loader order. A modality that received no
        batch maps to an empty tensor.
    """
    # Every list starts with the empty seed tensor the result used to be grown
    # from, so dtype promotion and the empty-loader result stay the same.
    # Batches are concatenated once at the end instead of re-copying the
    # accumulated tensor on every batch, which was quadratic in loader length.
    chunks = {mod: [torch.Tensor().to(device)] for mod in mod_list}
    for batch in test_loader:
        for mod in mod_list:
            chunks[mod].append(batch[mod].to(device))
    return {mod: torch.cat(parts) for mod, parts in chunks.items()}

Seed set to 42


In [13]:
# Number of modalities: test_sgima builds one `dim` entry per modality for the
# estimator's `dims` argument.
# NOTE(review): presumably must equal the total nb_var of the combined tasks
# (3 + 3) — confirm.
nb_mod = 6
# Dimensionality handed to the tasks, the dataset builder and the estimator.
dim=5
# Placeholder; rebound by the later cell that calls test_sgima(...).
r = {} 
#for sigma in [2.0,1.5,1.0,0.8,0.6,0.4,0.2,0.1,0.01]: 
def test_sgima(rho):
    """Train the O_Estimator baseline on two combined synergy tasks and evaluate it.

    Relies on the notebook-level ``dim`` and ``nb_mod`` variables.

    NOTE(review): the name looks like a typo of ``test_sigma`` (and the
    parameter is ``rho``); it is kept because a later cell calls it by this
    name.

    Args:
        rho: Value forwarded as ``rho`` to both ``Task_synergy`` components.

    Returns:
        dict with two entries: ``"gt"`` holds ``task.get_summary()`` and
        ``"e"`` holds the output of ``model.forward(model.test_samples)``
        computed after training.
    """
    task = Task_combination(
        tasks=[
            Task_synergy(nb_var=3, rho=rho, dim=dim),
            Task_synergy(nb_var=3, rho=rho, dim=dim),
        ],
        dim=dim,
    )
    N = 100 * 1000

    d_train, d_test = task.get_torch_dataset(N, 10000, dim=dim, rescale=False)

    train_loader = DataLoader(d_train, batch_size=64, shuffle=True,
                              num_workers=8, drop_last=True)

    test_loader = DataLoader(d_test, batch_size=64,
                             shuffle=False,
                             num_workers=8, drop_last=False)

    # The estimator receives the test DataLoader itself as its test samples.
    model = O_Estimator(
        dims=[dim for i in range(nb_mod)],
        test_samples=test_loader,
        gt=task.get_summary(),
        hidden_size=24,
        mi_estimator="InfoNCE",
        lr=1e-3,
        test_epoch=20,
    )

    CHECKPOINT_DIR = "trained_models/"
    tb_logger = TensorBoardLogger(save_dir=CHECKPOINT_DIR, name="baseline" + str(dim))
    trainer = pl.Trainer(
        logger=tb_logger,
        accelerator='gpu', devices=1,
        max_epochs=10,
        #num_sanity_val_steps=0,
        #strategy="ddp",
        default_root_dir=CHECKPOINT_DIR,
    )
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=test_loader)

    # A distinct local name avoids shadowing the notebook-level `r`.
    results = {}
    results["gt"] = task.get_summary()
    model.eval()
    # Final evaluation only: no autograd graph is needed, so skip building one.
    with torch.no_grad():
        results["e"] = model.forward(model.test_samples)
    return results
    


In [None]:
# Run one full train/evaluate cycle with rho = 0.5; `r` then holds the "gt"
# and "e" entries returned by test_sgima.
r=test_sgima(0.5)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name                 | Type       | Params
----------------------------------------------------
0 | mi_estimator_list_tc | ModuleList | 41.2 K
1 | mi_estimator_list_s  | ModuleList | 110 K 
----------------------------------------------------
151 K     Trainable params
0         Non-trainable params
151 K     Total params
0.605     Total estimated model params size (MB)


Sanity Checking: |                                                 | 0/? [00:00<?, ?it/s]

Training: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                      | 0/? [00:00<?, ?it/s]


  0%|                                                            | 0/157 [00:00<?, ?it/s][A
  1%|▎                                                   | 1/157 [00:00<00:49,  3.15it/s][A
  8%|███▉                                               | 12/157 [00:00<00:04, 36.04it/s][A
 15%|███████▊                                           | 24/157 [00:00<00:02, 61.62it/s][A
 22%|███████████▎                                       | 35/157 [00:00<00:01, 76.12it/s][A
 29%|██████████████▉                                    | 46/157 [00:00<00:01, 85.98it/s][A
 36%|██████████████████▌                                | 57/157 [00:00<00:01, 91.85it/s][A
 43%|██████████████████████                             | 68/157 [00:00<00:00, 96.06it/s][A
 51%|█████████████████████████▍                        | 80/157 [00:01<00:00, 100.59it/s][A
 58%|████████████████████████████▉                     | 91/157 [00:01<00:00, 101.67it/s][A
 65%|███████████████████████████████▊                 | 102/157 [00:0

Validation: |                                                      | 0/? [00:00<?, ?it/s]

Validation: |                                                      | 0/? [00:00<?, ?it/s]

In [None]:
# "o_inf" entry of the task's ground-truth summary
# (presumably the O-information — confirm against Task_combination.get_summary).
r["gt"]["o_inf"]

In [None]:
# Corresponding entry of the estimator output.
# NOTE(review): the key here is "o_if" while the ground-truth cell uses
# "o_inf" — possibly a typo that would raise KeyError; confirm against the
# keys returned by O_Estimator.forward.
r["e"]["o_if"]