In [1]:
import argparse
import os
import time

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from flamby.datasets.fed_tcga_brca import (
    BATCH_SIZE,
    LR,
    NUM_EPOCHS_POOLED,
    Baseline,
    BaselineLoss,
    FedTcgaBrca,
    metric,
)
from flamby.utils import evaluate_model_on_tests
import warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
def dict_mean(dict_list):
    """Return the key-wise arithmetic mean of a list of dictionaries.

    Keys are taken from the first dictionary, and every other dictionary must
    contain them (a missing key raises KeyError). An empty list raises
    IndexError.
    """
    first_entry = dict_list[0]
    count = len(dict_list)
    return {
        key: sum(entry[key] for entry in dict_list) / count
        for key in first_entry.keys()
    }

In [3]:
def train_model(
    model,
    optimizer,
    scheduler,
    dataloaders,
    dataset_sizes,
    device,
    lossfunc,
    num_epochs,
    seed,
    log,
    log_period,
):
    """Train a model on the training loader, then evaluate it on the test loader.

    Parameters
    ----------
    model : torch model to be trained
    optimizer : torch optimizer used for training
    scheduler : torch scheduler used for training; ``scheduler.step()`` is
        called once per training batch (not once per epoch)
    dataloaders : dictionary {"train": train_dataloader, "test": test_dataloader}
    dataset_sizes : dictionary {"train": len(train_dataset), "test": len(test_dataset)};
        only the "train" entry is read (to average the epoch loss)
    device : device where model parameters are stored
    lossfunc : function, loss function
    num_epochs : int, number of epochs for training
    seed : int, identifier of the run; only used to name the TensorBoard log
        directory (this function does not seed any RNG itself)
    log : bool, if True write TensorBoard scalars to ``./runs/seed{seed}``
    log_period : int, the number of batches between two dumps if log is activated.

    Returns
    -------
    tuple(torch.nn.Module, float) : torch model output by training loop and
    cindex on test.
    """
    since = time.time()

    writer = SummaryWriter(log_dir=f"./runs/seed{seed}") if log else None

    # Offset used to build a global step for TensorBoard. len(DataLoader) is the
    # number of batches per epoch (ceil(n / batch_size), or floor with
    # drop_last), so it stays right whatever batch size the loader was built with.
    num_local_steps_per_epoch = len(dataloaders["train"])

    try:
        model = model.train()
        for epoch in range(num_epochs):
            print("Epoch {}/{}".format(epoch + 1, num_epochs))
            print("-" * 10)

            running_loss = 0.0
            y_true = []
            y_pred = []

            # Iterate over data.
            for idx, (X, y) in enumerate(dataloaders["train"]):
                X = X.to(device)
                y = y.to(device)

                optimizer.zero_grad()
                outputs = model(X)
                loss = lossfunc(outputs, y)
                loss.backward()
                optimizer.step()

                # Detach before storing: keeping the raw outputs would retain
                # each batch's autograd graph until the end of the epoch.
                y_true.append(y.detach())
                y_pred.append(outputs.detach())

                batch_loss = loss.item()
                current_step = idx + num_local_steps_per_epoch * epoch
                if writer is not None and idx % log_period == 0:
                    writer.add_scalar("Loss/train/client", batch_loss, current_step)

                # Weight by batch size so the epoch loss is a per-sample mean.
                running_loss += batch_loss * X.size(0)

                # NOTE(review): stepped per batch, so MultiStepLR-style milestones
                # count optimizer steps rather than epochs. Kept as-is to preserve
                # the reported results; confirm this is the intended schedule.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes["train"]
            y_all = torch.cat(y_true)
            y_hat = torch.cat(y_pred)
            epoch_c_index = metric(y_all.cpu().numpy(), y_hat.cpu().numpy())
            if writer is not None:
                writer.add_scalar("Loss/average-per-epoch/client", epoch_loss, epoch)
                writer.add_scalar("C-index/full-training/client", epoch_c_index, epoch)

            print(
                "{} Loss: {:.4f} c-index: {:.4f}".format("train", epoch_loss, epoch_c_index)
            )

        # Evaluate the final model once on the held-out test loader.
        dict_cindex = evaluate_model_on_tests(model, [dataloaders["test"]], metric)
        test_c_index = dict_cindex["client_test_0"]

        if writer is not None:
            writer.add_scalar("Test/C-index", test_c_index, 0)
    finally:
        # Flush pending events and release the event file even if training fails.
        if writer is not None:
            writer.close()

    print()
    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )
    print()

    return model, test_c_index

In [4]:
# Pooled-training benchmark: for each seed, build a fresh Baseline, record its
# test C-index before training (results0), train it, and record the test
# C-index after training (results1).
NUM_SEEDS = 10
NUM_WORKERS = 4
LR_MILESTONES = [3, 5, 7, 9, 11, 13, 15, 17]  # scheduler steps at which LR is decayed
LR_GAMMA = 0.5  # multiplicative LR decay applied at each milestone

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device", device)

lossfunc = BaselineLoss()
num_epochs = NUM_EPOCHS_POOLED
log = False
log_period = 10

results0 = []  # per-seed dicts returned by evaluate_model_on_tests (untrained model)
results1 = []  # per-seed test C-index returned by train_model (trained model)
for seed in range(NUM_SEEDS):
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_dataset = FedTcgaBrca(train=True, pooled=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
    )
    test_dataset = FedTcgaBrca(train=False, pooled=True)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    dataloaders = {"train": train_dataloader, "test": test_dataloader}
    dataset_sizes = {"train": len(train_dataset), "test": len(test_dataset)}

    model = Baseline()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LR_MILESTONES, gamma=LR_GAMMA
    )

    # Baseline score of the freshly initialised model on the test set.
    results0.append(evaluate_model_on_tests(model, [test_dataloader], metric))

    model, test_cindex = train_model(
        model,
        optimizer,
        scheduler,
        dataloaders,
        dataset_sizes,
        device,
        lossfunc,
        num_epochs,
        seed,
        log,
        log_period,
    )
    results1.append(test_cindex)



cpu
device cpu


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.08it/s]

Epoch 0/30
----------





train Loss: 0.1772 c-index: 0.5208
Epoch 1/30
----------
train Loss: 0.1460 c-index: 0.5984
Epoch 2/30
----------
train Loss: 0.1460 c-index: 0.6259
Epoch 3/30
----------
train Loss: 0.1417 c-index: 0.6588
Epoch 4/30
----------
train Loss: 0.1407 c-index: 0.6848
Epoch 5/30
----------
train Loss: 0.1350 c-index: 0.6970
Epoch 6/30
----------
train Loss: 0.1311 c-index: 0.7064
Epoch 7/30
----------
train Loss: 0.1366 c-index: 0.7067
Epoch 8/30
----------
train Loss: 0.1303 c-index: 0.7118
Epoch 9/30
----------
train Loss: 0.1353 c-index: 0.7213
Epoch 10/30
----------
train Loss: 0.1332 c-index: 0.7220
Epoch 11/30
----------
train Loss: 0.1302 c-index: 0.7286
Epoch 12/30
----------
train Loss: 0.1225 c-index: 0.7280
Epoch 13/30
----------
train Loss: 0.1348 c-index: 0.7363
Epoch 14/30
----------
train Loss: 0.1280 c-index: 0.7314
Epoch 15/30
----------
train Loss: 0.1316 c-index: 0.7384
Epoch 16/30
----------
train Loss: 0.1223 c-index: 0.7397
Epoch 17/30
----------
train Loss: 0.1268 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.08it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.19it/s]

Epoch 0/30
----------





train Loss: 0.1691 c-index: 0.5277
Epoch 1/30
----------
train Loss: 0.1630 c-index: 0.5734
Epoch 2/30
----------
train Loss: 0.1604 c-index: 0.5773
Epoch 3/30
----------
train Loss: 0.1499 c-index: 0.5812
Epoch 4/30
----------
train Loss: 0.1492 c-index: 0.5848
Epoch 5/30
----------
train Loss: 0.1491 c-index: 0.5924
Epoch 6/30
----------
train Loss: 0.1480 c-index: 0.5972
Epoch 7/30
----------
train Loss: 0.1423 c-index: 0.6043
Epoch 8/30
----------
train Loss: 0.1452 c-index: 0.6140
Epoch 9/30
----------
train Loss: 0.1456 c-index: 0.6227
Epoch 10/30
----------
train Loss: 0.1492 c-index: 0.6295
Epoch 11/30
----------
train Loss: 0.1337 c-index: 0.6323
Epoch 12/30
----------
train Loss: 0.1331 c-index: 0.6421
Epoch 13/30
----------
train Loss: 0.1397 c-index: 0.6480
Epoch 14/30
----------
train Loss: 0.1366 c-index: 0.6581
Epoch 15/30
----------
train Loss: 0.1393 c-index: 0.6692
Epoch 16/30
----------
train Loss: 0.1366 c-index: 0.6819
Epoch 17/30
----------
train Loss: 0.1301 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.11it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.83it/s]

Epoch 0/30
----------





train Loss: 0.1693 c-index: 0.5610
Epoch 1/30
----------
train Loss: 0.1668 c-index: 0.5839
Epoch 2/30
----------
train Loss: 0.1510 c-index: 0.5922
Epoch 3/30
----------
train Loss: 0.1485 c-index: 0.5944
Epoch 4/30
----------
train Loss: 0.1494 c-index: 0.5985
Epoch 5/30
----------
train Loss: 0.1494 c-index: 0.6066
Epoch 6/30
----------
train Loss: 0.1517 c-index: 0.6096
Epoch 7/30
----------
train Loss: 0.1470 c-index: 0.6216
Epoch 8/30
----------
train Loss: 0.1476 c-index: 0.6277
Epoch 9/30
----------
train Loss: 0.1425 c-index: 0.6394
Epoch 10/30
----------
train Loss: 0.1475 c-index: 0.6410
Epoch 11/30
----------
train Loss: 0.1489 c-index: 0.6520
Epoch 12/30
----------
train Loss: 0.1414 c-index: 0.6487
Epoch 13/30
----------
train Loss: 0.1424 c-index: 0.6709
Epoch 14/30
----------
train Loss: 0.1439 c-index: 0.6747
Epoch 15/30
----------
train Loss: 0.1395 c-index: 0.6881
Epoch 16/30
----------
train Loss: 0.1345 c-index: 0.6954
Epoch 17/30
----------
train Loss: 0.1341 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.05it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.44it/s]

Epoch 0/30
----------





train Loss: 0.1703 c-index: 0.5335
Epoch 1/30
----------
train Loss: 0.1480 c-index: 0.5671
Epoch 2/30
----------
train Loss: 0.1548 c-index: 0.5790
Epoch 3/30
----------
train Loss: 0.1500 c-index: 0.5837
Epoch 4/30
----------
train Loss: 0.1505 c-index: 0.6029
Epoch 5/30
----------
train Loss: 0.1508 c-index: 0.5963
Epoch 6/30
----------
train Loss: 0.1470 c-index: 0.6200
Epoch 7/30
----------
train Loss: 0.1418 c-index: 0.6212
Epoch 8/30
----------
train Loss: 0.1427 c-index: 0.6377
Epoch 9/30
----------
train Loss: 0.1448 c-index: 0.6496
Epoch 10/30
----------
train Loss: 0.1362 c-index: 0.6668
Epoch 11/30
----------
train Loss: 0.1436 c-index: 0.6707
Epoch 12/30
----------
train Loss: 0.1341 c-index: 0.6788
Epoch 13/30
----------
train Loss: 0.1378 c-index: 0.6860
Epoch 14/30
----------
train Loss: 0.1413 c-index: 0.7002
Epoch 15/30
----------
train Loss: 0.1381 c-index: 0.7066
Epoch 16/30
----------
train Loss: 0.1327 c-index: 0.7118
Epoch 17/30
----------
train Loss: 0.1311 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.68it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.26it/s]

Epoch 0/30
----------





train Loss: 0.1606 c-index: 0.5947
Epoch 1/30
----------
train Loss: 0.1555 c-index: 0.6014
Epoch 2/30
----------
train Loss: 0.1698 c-index: 0.5993
Epoch 3/30
----------
train Loss: 0.1618 c-index: 0.6029
Epoch 4/30
----------
train Loss: 0.1546 c-index: 0.6104
Epoch 5/30
----------
train Loss: 0.1573 c-index: 0.6157
Epoch 6/30
----------
train Loss: 0.1523 c-index: 0.6124
Epoch 7/30
----------
train Loss: 0.1550 c-index: 0.6232
Epoch 8/30
----------
train Loss: 0.1507 c-index: 0.6300
Epoch 9/30
----------
train Loss: 0.1542 c-index: 0.6299
Epoch 10/30
----------
train Loss: 0.1467 c-index: 0.6420
Epoch 11/30
----------
train Loss: 0.1574 c-index: 0.6414
Epoch 12/30
----------
train Loss: 0.1488 c-index: 0.6498
Epoch 13/30
----------
train Loss: 0.1470 c-index: 0.6556
Epoch 14/30
----------
train Loss: 0.1437 c-index: 0.6645
Epoch 15/30
----------
train Loss: 0.1447 c-index: 0.6635
Epoch 16/30
----------
train Loss: 0.1454 c-index: 0.6698
Epoch 17/30
----------
train Loss: 0.1561 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.74it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.64it/s]

Epoch 0/30
----------





train Loss: 0.1518 c-index: 0.6542
Epoch 1/30
----------
train Loss: 0.1408 c-index: 0.6914
Epoch 2/30
----------
train Loss: 0.1343 c-index: 0.6974
Epoch 3/30
----------
train Loss: 0.1276 c-index: 0.7008
Epoch 4/30
----------
train Loss: 0.1301 c-index: 0.7045
Epoch 5/30
----------
train Loss: 0.1347 c-index: 0.7080
Epoch 6/30
----------
train Loss: 0.1359 c-index: 0.7050
Epoch 7/30
----------
train Loss: 0.1368 c-index: 0.7081
Epoch 8/30
----------
train Loss: 0.1289 c-index: 0.7162
Epoch 9/30
----------
train Loss: 0.1249 c-index: 0.7153
Epoch 10/30
----------
train Loss: 0.1338 c-index: 0.7207
Epoch 11/30
----------
train Loss: 0.1313 c-index: 0.7285
Epoch 12/30
----------
train Loss: 0.1252 c-index: 0.7284
Epoch 13/30
----------
train Loss: 0.1358 c-index: 0.7267
Epoch 14/30
----------
train Loss: 0.1285 c-index: 0.7324
Epoch 15/30
----------
train Loss: 0.1234 c-index: 0.7322
Epoch 16/30
----------
train Loss: 0.1378 c-index: 0.7343
Epoch 17/30
----------
train Loss: 0.1317 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.62it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.10it/s]

Epoch 0/30
----------





train Loss: 0.2294 c-index: 0.4634
Epoch 1/30
----------
train Loss: 0.2057 c-index: 0.4674
Epoch 2/30
----------
train Loss: 0.1842 c-index: 0.5174
Epoch 3/30
----------
train Loss: 0.1687 c-index: 0.5441
Epoch 4/30
----------
train Loss: 0.1619 c-index: 0.5614
Epoch 5/30
----------
train Loss: 0.1460 c-index: 0.5997
Epoch 6/30
----------
train Loss: 0.1519 c-index: 0.6214
Epoch 7/30
----------
train Loss: 0.1533 c-index: 0.6243
Epoch 8/30
----------
train Loss: 0.1484 c-index: 0.6396
Epoch 9/30
----------
train Loss: 0.1413 c-index: 0.6498
Epoch 10/30
----------
train Loss: 0.1451 c-index: 0.6551
Epoch 11/30
----------
train Loss: 0.1472 c-index: 0.6612
Epoch 12/30
----------
train Loss: 0.1526 c-index: 0.6662
Epoch 13/30
----------
train Loss: 0.1452 c-index: 0.6735
Epoch 14/30
----------
train Loss: 0.1409 c-index: 0.6782
Epoch 15/30
----------
train Loss: 0.1365 c-index: 0.6838
Epoch 16/30
----------
train Loss: 0.1356 c-index: 0.6898
Epoch 17/30
----------
train Loss: 0.1292 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.15it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.87it/s]

Epoch 0/30
----------





train Loss: 0.1571 c-index: 0.5792
Epoch 1/30
----------
train Loss: 0.1426 c-index: 0.6138
Epoch 2/30
----------
train Loss: 0.1437 c-index: 0.6425
Epoch 3/30
----------
train Loss: 0.1456 c-index: 0.6504
Epoch 4/30
----------
train Loss: 0.1375 c-index: 0.6580
Epoch 5/30
----------
train Loss: 0.1484 c-index: 0.6721
Epoch 6/30
----------
train Loss: 0.1362 c-index: 0.6746
Epoch 7/30
----------
train Loss: 0.1409 c-index: 0.6801
Epoch 8/30
----------
train Loss: 0.1427 c-index: 0.6910
Epoch 9/30
----------
train Loss: 0.1327 c-index: 0.6944
Epoch 10/30
----------
train Loss: 0.1403 c-index: 0.7018
Epoch 11/30
----------
train Loss: 0.1382 c-index: 0.7108
Epoch 12/30
----------
train Loss: 0.1364 c-index: 0.7139
Epoch 13/30
----------
train Loss: 0.1350 c-index: 0.7189
Epoch 14/30
----------
train Loss: 0.1325 c-index: 0.7250
Epoch 15/30
----------
train Loss: 0.1357 c-index: 0.7314
Epoch 16/30
----------
train Loss: 0.1265 c-index: 0.7415
Epoch 17/30
----------
train Loss: 0.1312 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.60it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.64it/s]

Epoch 0/30
----------





train Loss: 0.2119 c-index: 0.4439
Epoch 1/30
----------
train Loss: 0.1876 c-index: 0.4684
Epoch 2/30
----------
train Loss: 0.1718 c-index: 0.4849
Epoch 3/30
----------
train Loss: 0.1626 c-index: 0.5258
Epoch 4/30
----------
train Loss: 0.1591 c-index: 0.5408
Epoch 5/30
----------
train Loss: 0.1461 c-index: 0.5795
Epoch 6/30
----------
train Loss: 0.1496 c-index: 0.5932
Epoch 7/30
----------
train Loss: 0.1453 c-index: 0.6091
Epoch 8/30
----------
train Loss: 0.1395 c-index: 0.6247
Epoch 9/30
----------
train Loss: 0.1400 c-index: 0.6386
Epoch 10/30
----------
train Loss: 0.1373 c-index: 0.6404
Epoch 11/30
----------
train Loss: 0.1391 c-index: 0.6512
Epoch 12/30
----------
train Loss: 0.1507 c-index: 0.6507
Epoch 13/30
----------
train Loss: 0.1401 c-index: 0.6616
Epoch 14/30
----------
train Loss: 0.1381 c-index: 0.6703
Epoch 15/30
----------
train Loss: 0.1446 c-index: 0.6805
Epoch 16/30
----------
train Loss: 0.1345 c-index: 0.6839
Epoch 17/30
----------
train Loss: 0.1353 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.91it/s]



Training complete in 0m 9s



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.96it/s]

Epoch 0/30
----------





train Loss: 0.2436 c-index: 0.6254
Epoch 1/30
----------
train Loss: 0.2085 c-index: 0.6353
Epoch 2/30
----------
train Loss: 0.1962 c-index: 0.6252
Epoch 3/30
----------
train Loss: 0.1783 c-index: 0.6341
Epoch 4/30
----------
train Loss: 0.1775 c-index: 0.6335
Epoch 5/30
----------
train Loss: 0.1678 c-index: 0.6447
Epoch 6/30
----------
train Loss: 0.1499 c-index: 0.6495
Epoch 7/30
----------
train Loss: 0.1456 c-index: 0.6602
Epoch 8/30
----------
train Loss: 0.1499 c-index: 0.6692
Epoch 9/30
----------
train Loss: 0.1431 c-index: 0.6728
Epoch 10/30
----------
train Loss: 0.1326 c-index: 0.6828
Epoch 11/30
----------
train Loss: 0.1376 c-index: 0.6905
Epoch 12/30
----------
train Loss: 0.1386 c-index: 0.7005
Epoch 13/30
----------
train Loss: 0.1381 c-index: 0.7138
Epoch 14/30
----------
train Loss: 0.1345 c-index: 0.7179
Epoch 15/30
----------
train Loss: 0.1322 c-index: 0.7210
Epoch 16/30
----------
train Loss: 0.1390 c-index: 0.7254
Epoch 17/30
----------
train Loss: 0.1334 c-in

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10.19it/s]


Training complete in 0m 9s






In [17]:
# Report per-seed and averaged test C-index, before and after training.
# results0 holds dicts (averaged key-wise); results1 holds plain scores.
summary_sections = (
    ("Before training", results0, dict_mean),
    ("After training", results1, lambda scores: sum(scores) / len(scores)),
)
for title, per_seed_scores, average in summary_sections:
    print(title)
    print("Test C-index ", per_seed_scores)
    print("Average test C-index ", average(per_seed_scores))

Before training
Test C-index  [{'client_test_0': np.float64(0.7157509157509158)}, {'client_test_0': np.float64(0.7007326007326007)}, {'client_test_0': np.float64(0.6692307692307692)}, {'client_test_0': np.float64(0.3007326007326007)}, {'client_test_0': np.float64(0.7249084249084249)}, {'client_test_0': np.float64(0.6948717948717948)}, {'client_test_0': np.float64(0.6688644688644688)}, {'client_test_0': np.float64(0.5582417582417583)}, {'client_test_0': np.float64(0.7454212454212454)}, {'client_test_0': np.float64(0.6582417582417582)}]
Average test C-index  {'client_test_0': np.float64(0.6436996336996337)}
After training
Test C-index  [np.float64(0.8340659340659341), np.float64(0.7743589743589744), np.float64(0.8542124542124542), np.float64(0.8421245421245421), np.float64(0.8139194139194139), np.float64(0.8432234432234432), np.float64(0.8164835164835165), np.float64(0.8307692307692308), np.float64(0.823076923076923), np.float64(0.8311355311355312)]
Average test C-index  0.82633699633699

NameError: name 'm' is not defined

In [6]:
# Fresh, untrained Baseline for inspecting the architecture. This rebinds
# `model`, replacing whatever it held before (e.g. the last trained model).
model = Baseline()

In [7]:
model  # display the module repr (its layer structure)

Baseline(
  (fc): Linear(in_features=39, out_features=1, bias=True)
)