In [1]:
!pip install -q "monai[transformers, pandas]"
%matplotlib inline

In [2]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from sklearn.metrics import roc_auc_score
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.networks.nets import Transchex
from monai.config import print_config
from monai.utils import set_determinism
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer


torch.backends.cudnn.benchmark = True
print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.5.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /home/<username>/.local/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.14.1
Pillow version: 11.0.0
Tensorboard version: 2.18.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.1
tqdm version: 4.67.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.6
pandas version: 2.2.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: 4.40.2
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOW

In [4]:
datadir = "../data"
if not os.path.exists(datadir):
    os.makedirs(datadir)

set_determinism(seed=0)

In [5]:
class Dataset(Dataset):
    def __init__(self, dataframe, tokenizer, parent_dir, max_seq_length=512):
        
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.data = dataframe
        self.report_summary = self.data.report
        self.img_name = self.data.id
        self.targets = self.data.list

        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        self.parent_dir = parent_dir

    def __len__(self):
        return len(self.report_summary)

    def encode_features(self, sent, max_seq_length, tokenizer):
        tokens = tokenizer.tokenize(sent.strip())
        if len(tokens) > max_seq_length - 2:
            tokens = tokens[: (max_seq_length - 2)]
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        segment_ids = [0] * len(input_ids)
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            segment_ids.append(0)
        assert len(input_ids) == max_seq_length
        assert len(segment_ids) == max_seq_length
        return input_ids, segment_ids

    def __getitem__(self, index):
        name = self.img_name[index].split(".")[0]
        img_address = os.path.join(self.parent_dir, self.img_name[index])
        image = Image.open(img_address)
        images = self.preprocess(image)
        report = str(self.report_summary[index])
        report = " ".join(report.split())
        input_ids, segment_ids = self.encode_features(report, self.max_seq_length, self.tokenizer)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        targets = torch.tensor(self.targets[index], dtype=torch.float)
        return {
            "ids": input_ids,
            "segment_ids": segment_ids,
            "name": name,
            "targets": targets,
            "images": images,
        }

In [12]:
def load_txt_gt(add):
    txt_gt = pd.read_csv(add)
    txt_gt["list"] = txt_gt[txt_gt.columns[2:]].values.tolist()
    txt_gt = txt_gt[["id", "report", "list"]].copy()
    return txt_gt


logdir = "./logdir"
if not os.path.exists(logdir):
    os.makedirs(logdir)

parent_dir = "../data/dataset_proc/images/"
train_txt_gt = load_txt_gt("../data/dataset_proc/test.csv")
val_txt_gt = load_txt_gt("../data/dataset_proc/validation.csv")
test_txt_gt = load_txt_gt("../data/dataset_proc/test.csv")
batch_size = 32
num_workers = 0
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)
training_set = Dataset(train_txt_gt, tokenizer, parent_dir)
train_params = {
    "batch_size": batch_size,
    "shuffle": True,
    "num_workers": num_workers,
    "pin_memory": False,
}
training_loader = DataLoader(training_set, **train_params)
valid_set = Dataset(val_txt_gt, tokenizer, parent_dir)
test_set = Dataset(test_txt_gt, tokenizer, parent_dir)
valid_params = {"batch_size": 1, "shuffle": False, "num_workers": 1, "pin_memory": True}
val_loader = DataLoader(valid_set, **valid_params)
test_loader = DataLoader(test_set, **valid_params)



In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
total_epochs = 15
eval_num = 1
lr = 1e-4
weight_decay = 1e-5

model = Transchex(
    in_channels=3,
    img_size=(256, 256),
    num_classes=14,
    patch_size=(32, 32),
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
).to(device)

loss_bce = torch.nn.BCELoss().cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=total_epochs)
scheduler.step()  # To avoid lr=0 for Epoch 0.

  state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)


In [11]:
def save_ckp(state, checkpoint_dir):
    torch.save(state, checkpoint_dir)


def compute_aucs(gt, pred, num_classes=14):
    with torch.no_grad():
        aurocs = []
        gt_np = gt
        pred_np = pred
        for i in range(num_classes):
            aurocs.append(roc_auc_score(gt_np[:, i].tolist(), pred_np[:, i].tolist()))
    return aurocs


def train(epoch):
    model.train()
    for i, data in enumerate(training_loader, 0):
        input_ids = data["ids"].cuda()
        segment_ids = data["segment_ids"].cuda()
        img = data["images"].cuda()
        targets = data["targets"].cuda()
        logits_lang = model(input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids)
        loss = loss_bce(torch.sigmoid(logits_lang), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss}")


def validation(testing_loader):
    model.eval()
    targets_in = np.zeros((len(testing_loader), 14))
    preds_cls = np.zeros((len(testing_loader), 14))
    val_loss = []
    with torch.no_grad():
        for _, data in enumerate(testing_loader, 0):
            input_ids = data["ids"].cuda()
            segment_ids = data["segment_ids"].cuda()
            img = data["images"].cuda()
            targets = data["targets"].cuda()
            logits_lang = model(input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids)
            prob = torch.sigmoid(logits_lang)
            loss = loss_bce(prob, targets).item()
            targets_in[_, :] = targets.detach().cpu().numpy()
            preds_cls[_, :] = prob.detach().cpu().numpy()
            val_loss.append(loss)
        auc = compute_aucs(targets_in, preds_cls, 14)
        mean_auc = np.mean(auc)
        mean_loss = np.mean(val_loss)
        print("Evaluation Statistics: Mean AUC : {}, Mean Loss : {}".format(mean_auc, mean_loss))
    return mean_auc, mean_loss, auc


auc_val_best = 0.0
epoch_loss_values = []
metric_values = []
for epoch in range(total_epochs):
    train(epoch)
    auc_val, loss_val, _ = validation(val_loader)
    epoch_loss_values.append(loss_val)
    metric_values.append(auc_val)
    if auc_val > auc_val_best:
        checkpoint = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_ckp(checkpoint, logdir + "/transchex.pt")
        auc_val_best = auc_val
        print("Model Was Saved ! Current Best Validation AUC: {}    Current AUC: {}".format(auc_val_best, auc_val))
    else:
        print("Model Was NOT Saved ! Current Best Validation AUC: {}    Current AUC: {}".format(auc_val_best, auc_val))
    scheduler.step()


ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 

RuntimeError: DataLoader worker (pid(s) 562) exited unexpectedly

In [None]:
print(f"Training Finished ! Best Validation AUC: {auc_val_best:.4f} ")