In [1]:
import torch
from model import *
import numpy as np
import os
import random
import torch.optim as optim
import torch.nn as nn
from Myloader import *
import time
import torchvision.models as models
from torchmetrics.classification import MultilabelAveragePrecision
from ssl_encoder import *



In [2]:
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also disables cuDNN autotuning and requests deterministic cuDNN kernels,
    trading some speed for run-to-run reproducibility.

    Args:
        seed: integer seed applied to every RNG listed above.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


def evaluate(model, val_loader, criterion=None, device=None):
    """Run one pass over ``val_loader`` and score ``model`` on it.

    Computes the sample-weighted summed loss and the macro-averaged multilabel
    average precision (mAP) over the whole validation set.

    Args:
        model: network whose output for an image batch is treated as per-label
            logits, shape (batch, num_labels). Put into ``eval()`` mode and
            left there; the caller is responsible for switching back.
        val_loader: iterable of ``(imgs, labels, ids)`` triples. A trailing
            singleton dimension on ``labels`` is squeezed away; ``ids`` is ignored.
        criterion: loss applied as ``criterion(logits, labels)``. Defaults to
            ``nn.BCEWithLogitsLoss()``, the loss this notebook trains with.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``(mAP, test_running_loss, test_total)``: the torchmetrics macro-AP
        tensor, the loss summed over samples (divide by ``test_total`` for the
        per-sample mean), and the number of samples evaluated.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    test_running_loss = 0.0
    test_total = 0
    # Collect per-batch tensors and concatenate once at the end; calling
    # torch.cat inside the loop re-copies the whole history every batch.
    target_batches = []
    predict_batches = []

    with torch.no_grad():
        for test_imgs, test_labels, _ in val_loader:
            test_imgs = test_imgs.to(device)
            test_labels = test_labels.to(device).squeeze(-1)

            test_output = model(test_imgs)
            loss = criterion(test_output, test_labels)

            n = test_imgs.size(0)
            # Assumes criterion averages over the batch (mean reduction, the
            # BCEWithLogitsLoss default); re-weight so summing stays per-sample.
            test_running_loss += loss.item() * n
            test_total += n

            target_batches.append(test_labels)
            predict_batches.append(test_output)

        if not predict_batches:
            raise ValueError("val_loader yielded no batches; cannot compute mAP")

        record_target_label = torch.cat(target_batches, 0)
        record_predict_label = torch.cat(predict_batches, 0)

        # Label count is taken from the model output rather than hard-coded.
        num_labels = record_predict_label.size(1)
        metric = MultilabelAveragePrecision(
            num_labels=num_labels, average="macro", thresholds=None
        ).to(device)
        # torchmetrics expects integer multilabel targets.
        mAP = metric(record_predict_label, record_target_label.to(torch.int32))

    return mAP, test_running_loss, test_total

In [3]:
# Seed every RNG (python, numpy, torch CPU/CUDA) before the next cell builds the
# model and loaders, so weight initialisation (and presumably loader shuffling)
# is reproducible. Note: set_seed also forces deterministic cuDNN kernels,
# which can slow training.
SEED = 123
set_seed(SEED)

epochs = 100
batch_size = 32
num_classes = 19

# Checkpoint directory (created by the training cell before the first save).
weight_path = "weights/"

train_path = "data/MICCAI_long_tail_train.tfrecords"
train_index = "data/MICCAI_long_tail_train.tfindex"
val_path = "data/MICCAI_long_tail_val.tfrecords"
val_index = "data/MICCAI_long_tail_val.tfindex"
opt_lr = 1e-4
weight_decay = 0
training = True
train_name = ""
val_name = ""

In [4]:
# Use the GPU when one is visible; the model and every batch are placed on `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Backbone whose num_classes outputs are fed to BCEWithLogitsLoss as logits.
# (To resume from a saved backbone: encoder = torch.load('ssl_backbone2.pth'))
encoder = SSLEncoder(embDimension=num_classes)
encoder = encoder.to(device)
opt = optim.Adam(params=encoder.parameters(), lr=opt_lr, weight_decay=weight_decay)

# Training data is shuffled; validation order is fixed.
train_loader = Myloader(train_path, train_index, batch_size, shuffle=True, num_workers=0)
val_loader = Myloader(val_path, val_index, batch_size, shuffle=False, num_workers=0)

# Multilabel objective: independent sigmoid + binary cross-entropy per label.
criterion = nn.BCEWithLogitsLoss()

cuda


In [6]:
# Per-sample mean loss for each epoch; train and validation use the same scale.
train_losses = []
test_losses = []

# Print a progress line every LOG_EVERY training samples.
LOG_EVERY = 1024


def _save_checkpoint(path, model, optimizer):
    """Write ``model`` and ``optimizer`` state dicts to ``path`` via torch.save."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


if training:
    # torch.save does not create missing directories.
    os.makedirs(weight_path, exist_ok=True)

    # Mixed precision is only used on CUDA. With enabled=False both autocast and
    # the GradScaler become pass-throughs, so the same loop also runs on CPU.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    max_map = 0.0
    total = 0  # samples per epoch; unknown (0) until the first epoch finishes

    for epoch in range(epochs):
        encoder.train()
        running_loss = 0.0
        start_time = time.time()
        count = 0

        for imgs, labels, _ in train_loader:
            # The optimizer owns every encoder parameter, so this clears all grads.
            opt.zero_grad()

            imgs = imgs.to(device)
            labels = labels.to(device).squeeze(-1)

            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                output = encoder(imgs)
                # Loss inside the autocast region, per the PyTorch AMP recipe
                # (BCE-with-logits is autocast to float32 there).
                loss = criterion(output, labels)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            running_loss += loss.item() * imgs.size(0)
            count += imgs.size(0)

            if count % LOG_EVERY == 0:
                if total:
                    progress = f"{count}/{total} ({100 * count / total:.2f} %)"
                else:
                    progress = f"{count}/unknown"
                print(f"epoch {epoch}: {progress} finished / train loss: {running_loss / count}")

        if count == 0:
            raise RuntimeError("train_loader yielded no samples")
        total = count

        mAP, test_running_loss, test_total = evaluate(encoder, val_loader)
        mAP = float(mAP)
        test_loss = test_running_loss / test_total

        train_losses.append(running_loss / count)
        test_losses.append(test_loss)

        if mAP > max_map:
            max_map = mAP
            _save_checkpoint(os.path.join(weight_path, "ssl_model_best.pt"), encoder, opt)
        if epoch % 10 == 0:
            _save_checkpoint(os.path.join(weight_path, f"{epoch}epoch_ssl.pt"), encoder, opt)

        duration = time.time() - start_time
        print(f"epoch {epoch} / mAP: {mAP} / test loss: {test_loss} / duration: {duration}")





epoch 0: 1024/unknown finished / train loss: 0.43379271402955055
epoch 0: 2048/unknown finished / train loss: 0.3772197151556611
epoch 0: 3072/unknown finished / train loss: 0.3511858587153256
epoch 0: 4096/unknown finished / train loss: 0.34119522280525416
epoch 0: 5120/unknown finished / train loss: 0.33057605093345044
epoch 0: 6144/unknown finished / train loss: 0.3251800717941175
epoch 0: 7168/unknown finished / train loss: 0.32235857252297656
epoch 0: 8192/unknown finished / train loss: 0.31886282440973446
epoch 0: 9216/unknown finished / train loss: 0.3142536953609023
epoch 0: 10240/unknown finished / train loss: 0.3121962255332619
epoch 0: 11264/unknown finished / train loss: 0.3102788572961634
epoch 0: 12288/unknown finished / train loss: 0.3091770372508715
epoch 0: 13312/unknown finished / train loss: 0.3066760735729566
epoch 0: 14336/unknown finished / train loss: 0.305360821940537
epoch 0: 15360/unknown finished / train loss: 0.3046490643794338
epoch 0: 16384/unknown finishe

epoch 0: 130048/unknown finished / train loss: 0.2833803307455708
epoch 0: 131072/unknown finished / train loss: 0.28338045490454533
epoch 0: 132096/unknown finished / train loss: 0.28325368595027994
epoch 0: 133120/unknown finished / train loss: 0.2831935370591684
epoch 0: 134144/unknown finished / train loss: 0.2831644705063059
epoch 0: 135168/unknown finished / train loss: 0.2830942333005651
epoch 0: 136192/unknown finished / train loss: 0.28300263613280385
epoch 0: 137216/unknown finished / train loss: 0.28297923746129583
epoch 0: 138240/unknown finished / train loss: 0.2829371162345288
epoch 0: 139264/unknown finished / train loss: 0.2828768755459939
epoch 0: 140288/unknown finished / train loss: 0.2828370246995431
epoch 0: 141312/unknown finished / train loss: 0.2827835835624432
epoch 0: 142336/unknown finished / train loss: 0.2826921896030791
epoch 0: 143360/unknown finished / train loss: 0.282563417907139
epoch 0: 144384/unknown finished / train loss: 0.28258217886380904
epoch 

KeyboardInterrupt: 