In [1]:
import argparse
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    FedTcgaBrca,
    metric,
)
from flamby.utils import evaluate_model_on_tests
import warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
def dict_mean(dict_list):
    """Average a list of dictionaries key by key.

    The set of keys is read from the first dictionary; every dictionary in
    ``dict_list`` is then indexed with each of those keys, and the values are
    summed and divided by the number of dictionaries.

    Parameters
    ----------
    dict_list : non-empty sequence of dictionaries sharing the keys of the
        first one.

    Returns
    -------
    dict : maps each key of the first dictionary to the mean of its values.
    """
    count = len(dict_list)
    reference_keys = dict_list[0].keys()
    return {
        name: sum(entry[name] for entry in dict_list) / count
        for name in reference_keys
    }

In [3]:
def train_model(
    model,
    optimizer,
    scheduler,
    dataloaders,
    dataset_sizes,
    device,
    lossfunc,
    num_epochs,
    seed,
    log,
    log_period,
):
    """Train ``model`` on the train loader, then evaluate it on the test loader.

    Parameters
    ----------
    model : torch model to be trained
    optimizer : torch optimizer used for training
    scheduler : torch scheduler used for training. It is stepped once per
        batch (not once per epoch).
    dataloaders : dictionary {"train": train_dataloader, "test": test_dataloader}
    dataset_sizes : dictionary {"train": len(train_dataset), "test": len(test_dataset)}
    device : device where model parameters are stored
    lossfunc : function, loss function
    num_epochs : int, number of epochs for training
    seed : int, the seed of the current run. Inside this function it is only
        used to name the TensorBoard log directory (``./runs/seed{seed}``).
    log : bool, whether to write TensorBoard summaries.
    log_period : int, the number of batches between two dumps if log is activated.

    Returns
    -------
    tuple(torch.nn.Module, float) : torch model output by training loop and
    cindex on test.
    """

    since = time.time()

    # The writer only exists when logging is requested; it is closed in the
    # ``finally`` clause below so pending events are flushed even on error.
    writer = SummaryWriter(log_dir=f"./runs/seed{seed}") if log else None

    try:
        # Number of optimizer steps per epoch, read from the loader itself so
        # the global logging step stays correct whatever batch size, sampler
        # or drop_last setting the loader was built with.
        num_local_steps_per_epoch = len(dataloaders["train"])

        model = model.train()
        for epoch in range(0, num_epochs):
            print("Epoch {}/{}".format(epoch, num_epochs))
            print("-" * 10)

            running_loss = 0.0
            y_true = []
            y_pred = []

            # Iterate over data.
            for idx, (X, y) in enumerate(dataloaders["train"]):
                X = X.to(device)
                y = y.to(device)
                y_true.append(y)

                optimizer.zero_grad()
                outputs = model(X)
                # Store a detached view: keeping the attached tensor would keep
                # every batch's autograd graph alive until the end of the epoch.
                y_pred.append(outputs.detach())
                loss = lossfunc(outputs, y)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                current_step = idx + num_local_steps_per_epoch * epoch

                if writer is not None and (idx % log_period) == 0:
                    writer.add_scalar("Loss/train/client", batch_loss, current_step)

                # Weight by the batch size so the epoch loss is a per-sample mean.
                running_loss += batch_loss * X.size(0)

                scheduler.step()

            epoch_loss = running_loss / dataset_sizes["train"]
            y = torch.cat(y_true)
            y_hat = torch.cat(y_pred)
            epoch_c_index = metric(
                y.cpu().detach().numpy(), y_hat.cpu().detach().numpy()
            )
            if writer is not None:
                writer.add_scalar("Loss/average-per-epoch/client", epoch_loss, epoch)
                writer.add_scalar("C-index/full-training/client", epoch_c_index, epoch)

            print(
                "{} Loss: {:.4f} c-index: {:.4f}".format(
                    "train", epoch_loss, epoch_c_index
                )
            )

        # Final evaluation on the single test loader; the result is read back
        # under the "client_test_0" key.
        dict_cindex = evaluate_model_on_tests(model, [dataloaders["test"]], metric)

        if writer is not None:
            writer.add_scalar("Test/C-index", dict_cindex["client_test_0"], 0)
    finally:
        if writer is not None:
            writer.close()

    print()
    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )
    print()

    return model, dict_cindex["client_test_0"]

In [4]:
# torch.use_deterministic_algorithms(False)

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
print(f"device {device}")

# Settings shared by every seeded run.
lossfunc, num_epochs = BaselineLoss(), NUM_EPOCHS_POOLED
log, log_period = False, 10

# results0 collects the metric dicts measured before training,
# results1 the test c-index returned by train_model after training.
results0, results1 = [], []
for seed in range(10):
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Pooled train / test splits, rebuilt for every seed.
    train_dataset = FedTcgaBrca(train=True, pooled=True)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4
    )
    test_dataset = FedTcgaBrca(train=False, pooled=True)
    # drop_last is left at its default (the original kept it commented out).
    test_dataloader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4
    )

    dataloaders = dict(train=train_dataloader, test=test_dataloader)
    dataset_sizes = dict(train=len(train_dataset), test=len(test_dataset))

    model = Baseline().to(device)
    optimizer = torch.optim.Adam(params=model.fc.parameters(), lr=LR)
    # Halve the learning rate at scheduler steps 3, 5, 7, ..., 17.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, gamma=0.5, milestones=list(range(3, 18, 2))
    )

    # Metric of the freshly built model, measured before any training step.
    results0.append(evaluate_model_on_tests(model, [test_dataloader], metric))

    model, test_cindex = train_model(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloaders=dataloaders,
        dataset_sizes=dataset_sizes,
        device=device,
        lossfunc=lossfunc,
        num_epochs=num_epochs,
        seed=seed,
        log=log,
        log_period=log_period,
    )
    results1.append(test_cindex)



cuda:0
device cuda:0


100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.10it/s]

Epoch 0/30
----------





train Loss: 0.1773 c-index: 0.5207
Epoch 1/30
----------
train Loss: 0.1461 c-index: 0.5983
Epoch 2/30
----------
train Loss: 0.1460 c-index: 0.6258
Epoch 3/30
----------
train Loss: 0.1418 c-index: 0.6590
Epoch 4/30
----------
train Loss: 0.1407 c-index: 0.6847
Epoch 5/30
----------
train Loss: 0.1350 c-index: 0.6970
Epoch 6/30
----------
train Loss: 0.1311 c-index: 0.7061
Epoch 7/30
----------
train Loss: 0.1366 c-index: 0.7066
Epoch 8/30
----------
train Loss: 0.1303 c-index: 0.7117
Epoch 9/30
----------
train Loss: 0.1354 c-index: 0.7213
Epoch 10/30
----------
train Loss: 0.1332 c-index: 0.7218
Epoch 11/30
----------
train Loss: 0.1303 c-index: 0.7284
Epoch 12/30
----------
train Loss: 0.1225 c-index: 0.7279
Epoch 13/30
----------
train Loss: 0.1349 c-index: 0.7362
Epoch 14/30
----------
train Loss: 0.1280 c-index: 0.7308
Epoch 15/30
----------
train Loss: 0.1317 c-index: 0.7381
Epoch 16/30
----------
train Loss: 0.1223 c-index: 0.7394
Epoch 17/30
----------
train Loss: 0.1268 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.51it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.78it/s]

Epoch 0/30
----------





train Loss: 0.1693 c-index: 0.5264
Epoch 1/30
----------
train Loss: 0.1622 c-index: 0.5743
Epoch 2/30
----------
train Loss: 0.1602 c-index: 0.5796
Epoch 3/30
----------
train Loss: 0.1500 c-index: 0.5837
Epoch 4/30
----------
train Loss: 0.1493 c-index: 0.5873
Epoch 5/30
----------
train Loss: 0.1495 c-index: 0.5949
Epoch 6/30
----------
train Loss: 0.1479 c-index: 0.5988
Epoch 7/30
----------
train Loss: 0.1420 c-index: 0.6063
Epoch 8/30
----------
train Loss: 0.1452 c-index: 0.6159
Epoch 9/30
----------
train Loss: 0.1458 c-index: 0.6246
Epoch 10/30
----------
train Loss: 0.1489 c-index: 0.6304
Epoch 11/30
----------
train Loss: 0.1336 c-index: 0.6335
Epoch 12/30
----------
train Loss: 0.1333 c-index: 0.6423
Epoch 13/30
----------
train Loss: 0.1394 c-index: 0.6485
Epoch 14/30
----------
train Loss: 0.1366 c-index: 0.6582
Epoch 15/30
----------
train Loss: 0.1395 c-index: 0.6690
Epoch 16/30
----------
train Loss: 0.1365 c-index: 0.6813
Epoch 17/30
----------
train Loss: 0.1300 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.54it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.86it/s]

Epoch 0/30
----------





train Loss: 0.1696 c-index: 0.5608
Epoch 1/30
----------
train Loss: 0.1674 c-index: 0.5818
Epoch 2/30
----------
train Loss: 0.1514 c-index: 0.5901
Epoch 3/30
----------
train Loss: 0.1488 c-index: 0.5908
Epoch 4/30
----------
train Loss: 0.1499 c-index: 0.5966
Epoch 5/30
----------
train Loss: 0.1500 c-index: 0.6048
Epoch 6/30
----------
train Loss: 0.1521 c-index: 0.6068
Epoch 7/30
----------
train Loss: 0.1476 c-index: 0.6187
Epoch 8/30
----------
train Loss: 0.1482 c-index: 0.6251
Epoch 9/30
----------
train Loss: 0.1427 c-index: 0.6370
Epoch 10/30
----------
train Loss: 0.1479 c-index: 0.6384
Epoch 11/30
----------
train Loss: 0.1490 c-index: 0.6494
Epoch 12/30
----------
train Loss: 0.1418 c-index: 0.6458
Epoch 13/30
----------
train Loss: 0.1429 c-index: 0.6683
Epoch 14/30
----------
train Loss: 0.1442 c-index: 0.6723
Epoch 15/30
----------
train Loss: 0.1399 c-index: 0.6857
Epoch 16/30
----------
train Loss: 0.1351 c-index: 0.6934
Epoch 17/30
----------
train Loss: 0.1342 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.47it/s]



Training complete in 0m 18s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.05it/s]

Epoch 0/30
----------





train Loss: 0.1703 c-index: 0.5335
Epoch 1/30
----------
train Loss: 0.1480 c-index: 0.5674
Epoch 2/30
----------
train Loss: 0.1547 c-index: 0.5798
Epoch 3/30
----------
train Loss: 0.1500 c-index: 0.5837
Epoch 4/30
----------
train Loss: 0.1505 c-index: 0.6034
Epoch 5/30
----------
train Loss: 0.1508 c-index: 0.5963
Epoch 6/30
----------
train Loss: 0.1471 c-index: 0.6204
Epoch 7/30
----------
train Loss: 0.1418 c-index: 0.6213
Epoch 8/30
----------
train Loss: 0.1426 c-index: 0.6379
Epoch 9/30
----------
train Loss: 0.1448 c-index: 0.6497
Epoch 10/30
----------
train Loss: 0.1363 c-index: 0.6671
Epoch 11/30
----------
train Loss: 0.1435 c-index: 0.6710
Epoch 12/30
----------
train Loss: 0.1341 c-index: 0.6791
Epoch 13/30
----------
train Loss: 0.1378 c-index: 0.6860
Epoch 14/30
----------
train Loss: 0.1413 c-index: 0.7005
Epoch 15/30
----------
train Loss: 0.1382 c-index: 0.7071
Epoch 16/30
----------
train Loss: 0.1328 c-index: 0.7121
Epoch 17/30
----------
train Loss: 0.1311 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.97it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.21it/s]

Epoch 0/30
----------





train Loss: 0.1605 c-index: 0.5944
Epoch 1/30
----------
train Loss: 0.1554 c-index: 0.6018
Epoch 2/30
----------
train Loss: 0.1696 c-index: 0.5991
Epoch 3/30
----------
train Loss: 0.1617 c-index: 0.6026
Epoch 4/30
----------
train Loss: 0.1545 c-index: 0.6101
Epoch 5/30
----------
train Loss: 0.1572 c-index: 0.6157
Epoch 6/30
----------
train Loss: 0.1522 c-index: 0.6124
Epoch 7/30
----------
train Loss: 0.1548 c-index: 0.6227
Epoch 8/30
----------
train Loss: 0.1506 c-index: 0.6291
Epoch 9/30
----------
train Loss: 0.1542 c-index: 0.6300
Epoch 10/30
----------
train Loss: 0.1464 c-index: 0.6422
Epoch 11/30
----------
train Loss: 0.1573 c-index: 0.6408
Epoch 12/30
----------
train Loss: 0.1487 c-index: 0.6504
Epoch 13/30
----------
train Loss: 0.1471 c-index: 0.6558
Epoch 14/30
----------
train Loss: 0.1437 c-index: 0.6647
Epoch 15/30
----------
train Loss: 0.1446 c-index: 0.6630
Epoch 16/30
----------
train Loss: 0.1453 c-index: 0.6695
Epoch 17/30
----------
train Loss: 0.1560 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.52it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.45it/s]

Epoch 0/30
----------





train Loss: 0.1520 c-index: 0.6538
Epoch 1/30
----------
train Loss: 0.1412 c-index: 0.6906
Epoch 2/30
----------
train Loss: 0.1345 c-index: 0.6968
Epoch 3/30
----------
train Loss: 0.1278 c-index: 0.7005
Epoch 4/30
----------
train Loss: 0.1301 c-index: 0.7039
Epoch 5/30
----------
train Loss: 0.1349 c-index: 0.7069
Epoch 6/30
----------
train Loss: 0.1361 c-index: 0.7037
Epoch 7/30
----------
train Loss: 0.1369 c-index: 0.7072
Epoch 8/30
----------
train Loss: 0.1290 c-index: 0.7148
Epoch 9/30
----------
train Loss: 0.1251 c-index: 0.7140
Epoch 10/30
----------
train Loss: 0.1340 c-index: 0.7202
Epoch 11/30
----------
train Loss: 0.1313 c-index: 0.7261
Epoch 12/30
----------
train Loss: 0.1255 c-index: 0.7267
Epoch 13/30
----------
train Loss: 0.1359 c-index: 0.7261
Epoch 14/30
----------
train Loss: 0.1285 c-index: 0.7319
Epoch 15/30
----------
train Loss: 0.1235 c-index: 0.7315
Epoch 16/30
----------
train Loss: 0.1380 c-index: 0.7337
Epoch 17/30
----------
train Loss: 0.1317 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.67it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.25it/s]

Epoch 0/30
----------





train Loss: 0.2300 c-index: 0.4625
Epoch 1/30
----------
train Loss: 0.2063 c-index: 0.4665
Epoch 2/30
----------
train Loss: 0.1846 c-index: 0.5159
Epoch 3/30
----------
train Loss: 0.1692 c-index: 0.5425
Epoch 4/30
----------
train Loss: 0.1623 c-index: 0.5603
Epoch 5/30
----------
train Loss: 0.1462 c-index: 0.5987
Epoch 6/30
----------
train Loss: 0.1523 c-index: 0.6201
Epoch 7/30
----------
train Loss: 0.1537 c-index: 0.6236
Epoch 8/30
----------
train Loss: 0.1487 c-index: 0.6389
Epoch 9/30
----------
train Loss: 0.1416 c-index: 0.6489
Epoch 10/30
----------
train Loss: 0.1453 c-index: 0.6545
Epoch 11/30
----------
train Loss: 0.1476 c-index: 0.6605
Epoch 12/30
----------
train Loss: 0.1526 c-index: 0.6657
Epoch 13/30
----------
train Loss: 0.1451 c-index: 0.6727
Epoch 14/30
----------
train Loss: 0.1412 c-index: 0.6771
Epoch 15/30
----------
train Loss: 0.1366 c-index: 0.6828
Epoch 16/30
----------
train Loss: 0.1357 c-index: 0.6892
Epoch 17/30
----------
train Loss: 0.1291 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.82it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.22it/s]

Epoch 0/30
----------





train Loss: 0.1571 c-index: 0.5809
Epoch 1/30
----------
train Loss: 0.1426 c-index: 0.6148
Epoch 2/30
----------
train Loss: 0.1437 c-index: 0.6431
Epoch 3/30
----------
train Loss: 0.1455 c-index: 0.6516
Epoch 4/30
----------
train Loss: 0.1376 c-index: 0.6591
Epoch 5/30
----------
train Loss: 0.1484 c-index: 0.6726
Epoch 6/30
----------
train Loss: 0.1361 c-index: 0.6755
Epoch 7/30
----------
train Loss: 0.1409 c-index: 0.6812
Epoch 8/30
----------
train Loss: 0.1427 c-index: 0.6915
Epoch 9/30
----------
train Loss: 0.1327 c-index: 0.6949
Epoch 10/30
----------
train Loss: 0.1403 c-index: 0.7021
Epoch 11/30
----------
train Loss: 0.1382 c-index: 0.7112
Epoch 12/30
----------
train Loss: 0.1363 c-index: 0.7141
Epoch 13/30
----------
train Loss: 0.1350 c-index: 0.7191
Epoch 14/30
----------
train Loss: 0.1325 c-index: 0.7254
Epoch 15/30
----------
train Loss: 0.1357 c-index: 0.7318
Epoch 16/30
----------
train Loss: 0.1264 c-index: 0.7418
Epoch 17/30
----------
train Loss: 0.1311 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.22it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.35it/s]

Epoch 0/30
----------





train Loss: 0.2130 c-index: 0.4417
Epoch 1/30
----------
train Loss: 0.1900 c-index: 0.4625
Epoch 2/30
----------
train Loss: 0.1731 c-index: 0.4784
Epoch 3/30
----------
train Loss: 0.1634 c-index: 0.5201
Epoch 4/30
----------
train Loss: 0.1603 c-index: 0.5342
Epoch 5/30
----------
train Loss: 0.1465 c-index: 0.5745
Epoch 6/30
----------
train Loss: 0.1500 c-index: 0.5888
Epoch 7/30
----------
train Loss: 0.1458 c-index: 0.6047
Epoch 8/30
----------
train Loss: 0.1400 c-index: 0.6219
Epoch 9/30
----------
train Loss: 0.1405 c-index: 0.6336
Epoch 10/30
----------
train Loss: 0.1371 c-index: 0.6371
Epoch 11/30
----------
train Loss: 0.1398 c-index: 0.6478
Epoch 12/30
----------
train Loss: 0.1508 c-index: 0.6477
Epoch 13/30
----------
train Loss: 0.1404 c-index: 0.6596
Epoch 14/30
----------
train Loss: 0.1382 c-index: 0.6689
Epoch 15/30
----------
train Loss: 0.1448 c-index: 0.6784
Epoch 16/30
----------
train Loss: 0.1345 c-index: 0.6827
Epoch 17/30
----------
train Loss: 0.1356 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.21it/s]



Training complete in 0m 17s



100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.49it/s]

Epoch 0/30
----------





train Loss: 0.2437 c-index: 0.6254
Epoch 1/30
----------
train Loss: 0.2086 c-index: 0.6351
Epoch 2/30
----------
train Loss: 0.1963 c-index: 0.6251
Epoch 3/30
----------
train Loss: 0.1784 c-index: 0.6339
Epoch 4/30
----------
train Loss: 0.1776 c-index: 0.6333
Epoch 5/30
----------
train Loss: 0.1679 c-index: 0.6444
Epoch 6/30
----------
train Loss: 0.1500 c-index: 0.6494
Epoch 7/30
----------
train Loss: 0.1457 c-index: 0.6600
Epoch 8/30
----------
train Loss: 0.1500 c-index: 0.6693
Epoch 9/30
----------
train Loss: 0.1431 c-index: 0.6726
Epoch 10/30
----------
train Loss: 0.1327 c-index: 0.6826
Epoch 11/30
----------
train Loss: 0.1377 c-index: 0.6904
Epoch 12/30
----------
train Loss: 0.1387 c-index: 0.6999
Epoch 13/30
----------
train Loss: 0.1382 c-index: 0.7134
Epoch 14/30
----------
train Loss: 0.1346 c-index: 0.7176
Epoch 15/30
----------
train Loss: 0.1323 c-index: 0.7205
Epoch 16/30
----------
train Loss: 0.1391 c-index: 0.7253
Epoch 17/30
----------
train Loss: 0.1335 c-in

100%|██████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.33it/s]


Training complete in 0m 17s






In [17]:
# Report the test c-index over all seeds, before and after training.
# results0 holds one metric dict per seed, so it is averaged with dict_mean;
# results1 holds plain numbers and is averaged directly.
print("Before training")
print(f"Test C-index  {results0}")
print(f"Average test C-index  {dict_mean(results0)}")
print("After training")
print(f"Test C-index  {results1}")
print(f"Average test C-index  {sum(results1) / len(results1)}")

Before training
Test C-index  [{'client_test_0': np.float64(0.7157509157509158)}, {'client_test_0': np.float64(0.7007326007326007)}, {'client_test_0': np.float64(0.6692307692307692)}, {'client_test_0': np.float64(0.3007326007326007)}, {'client_test_0': np.float64(0.7249084249084249)}, {'client_test_0': np.float64(0.6948717948717948)}, {'client_test_0': np.float64(0.6688644688644688)}, {'client_test_0': np.float64(0.5582417582417583)}, {'client_test_0': np.float64(0.7454212454212454)}, {'client_test_0': np.float64(0.6582417582417582)}]
Average test C-index  {'client_test_0': np.float64(0.6436996336996337)}
After training
Test C-index  [np.float64(0.8340659340659341), np.float64(0.7743589743589744), np.float64(0.8542124542124542), np.float64(0.8421245421245421), np.float64(0.8139194139194139), np.float64(0.8432234432234432), np.float64(0.8164835164835165), np.float64(0.8307692307692308), np.float64(0.823076923076923), np.float64(0.8311355311355312)]
Average test C-index  0.82633699633699

NameError: name 'm' is not defined

In [6]:
# Build a new Baseline instance for inspection. NOTE: this rebinds the
# notebook-level name `model`, which the training cell also assigns.
model=Baseline()

In [7]:
# Bare expression: the notebook displays the repr of `model` as the cell output.
model

Baseline(
  (fc): Linear(in_features=39, out_features=1, bias=True)
)