In [1]:
!pip install -q "monai[transformers, pandas]"
!pip install -q scikit-learn==1.0.2
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [2]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from sklearn.metrics import roc_auc_score
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.networks.nets import Transchex
from monai.config import print_config
from monai.utils import set_determinism
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Root folder for the dataset files read by the later cells.
datadir = "./monai_data"
# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(datadir, exist_ok=True)

In [4]:
# Enable cuDNN's benchmark mode.
# NOTE(review): a later cell calls set_determinism(seed=0); MONAI's helper may reset the
# cuDNN flags, in which case this setting has no lasting effect — confirm.
torch.backends.cudnn.benchmark = True

# Print the MONAI version together with the optional dependencies it detected.
print_config()

MONAI version: 0.9.1
Numpy version: 1.22.4
Pytorch version: 1.10.2+cu102
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 356d2d2f41b473f588899d705bbc682308cee52c
MONAI __file__: c:\Users\rahma\miniconda3\envs\myenv\lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.0.1
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.11.3+cu102
tqdm version: 4.64.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: 1.4.1
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: 4.21.1
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the option

In [5]:
# Fix the random seed through MONAI's helper so that repeated runs are comparable.
set_determinism(seed=0)

In [6]:
class MultiModalDataset(Dataset):
    """Paired image / report dataset for multi-label classification.

    Every row of ``dataframe`` has to provide three columns:

    * ``id``     - image file name, relative to ``parent_dir``
    * ``report`` - free-text report that belongs to the image
    * ``list``   - per-class target values of the sample

    Args:
        dataframe: pandas DataFrame with the columns described above.
        tokenizer: object exposing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``.
        parent_dir: folder that contains the image files.
        max_seq_length: fixed length of the returned token id sequence,
            including the ``[CLS]`` and ``[SEP]`` markers.
    """

    def __init__(self, dataframe, tokenizer, parent_dir, max_seq_length=512):
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.data = dataframe
        # Bracket access instead of attribute access: ``dataframe.list`` style lookups
        # stop working as soon as a column name collides with a DataFrame attribute.
        self.report_summary = self.data["report"]
        self.img_name = self.data["id"]
        self.targets = self.data["list"]

        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        self.parent_dir = parent_dir

    def __len__(self):
        """Return the number of samples (one per report)."""
        return len(self.report_summary)

    def encode_features(self, sent, max_seq_length, tokenizer):
        """Turn ``sent`` into fixed-length token ids and segment ids.

        The sentence is tokenized, truncated so that ``[CLS]`` and ``[SEP]`` still
        fit, converted to ids and right-padded with ``0`` up to ``max_seq_length``.

        Returns:
            ``(input_ids, segment_ids)`` - two lists of length ``max_seq_length``;
            the segment ids are all ``0``.
        """
        # Two positions are reserved for the [CLS] / [SEP] markers.
        tokens = tokenizer.tokenize(sent.strip())[: max_seq_length - 2]
        tokens = ["[CLS]", *tokens, "[SEP]"]
        input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
        input_ids += [0] * (max_seq_length - len(input_ids))
        segment_ids = [0] * len(input_ids)
        # Only violated when max_seq_length is too small to hold both markers.
        assert len(input_ids) == max_seq_length
        assert len(segment_ids) == max_seq_length
        return input_ids, segment_ids

    def __getitem__(self, index):
        """Load the sample at position ``index``.

        Returns:
            dict with the keys ``ids`` and ``segment_ids`` (long tensors),
            ``name`` (file name up to the first dot), ``targets`` (float tensor)
            and ``images`` (preprocessed image tensor).
        """
        # .iloc: samplers hand out positions, which only coincide with the
        # Series labels while the frame still carries its default index.
        file_name = self.img_name.iloc[index]
        name = file_name.split(".")[0]
        img_address = os.path.join(self.parent_dir, file_name)
        # The context manager closes the file handle; convert("RGB") guarantees the
        # three channels that the Normalize step above is configured for.
        with Image.open(img_address) as image:
            images = self.preprocess(image.convert("RGB"))
        # Collapse every run of whitespace in the report into a single blank.
        report = " ".join(str(self.report_summary.iloc[index]).split())
        input_ids, segment_ids = self.encode_features(
            report, self.max_seq_length, self.tokenizer
        )
        return {
            "ids": torch.tensor(input_ids, dtype=torch.long),
            "segment_ids": torch.tensor(segment_ids, dtype=torch.long),
            "name": name,
            "targets": torch.tensor(self.targets.iloc[index], dtype=torch.float),
            "images": images,
        }

In [10]:
def load_txt_gt(add):
    """Load an annotation CSV and collapse its label columns into one column.

    The file at ``add`` is read with pandas. Every column from the third one
    onwards is gathered, row by row, into a Python list stored under ``list``.

    Returns:
        A new DataFrame that holds only the ``id``, ``report`` and ``list`` columns.
    """
    table = pd.read_csv(add)
    label_columns = table.columns[2:]
    table["list"] = table[label_columns].values.tolist()
    kept_columns = ["id", "report", "list"]
    return table[kept_columns].copy()


# Folder that receives the model checkpoints.
logdir = "./logdir"
# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(logdir, exist_ok=True)

# Image folder and the three annotation tables (train / validation / test).
parent_dir = "./monai_data/dataset_proc/images/"
train_txt_gt = load_txt_gt("./monai_data/train.csv")
val_txt_gt = load_txt_gt("./monai_data/validation.csv")
test_txt_gt = load_txt_gt("./monai_data/test.csv")

batch_size = 32
# NOTE(review): worker processes need to import the dataset class; in a notebook on
# Windows this may require num_workers = 0 — confirm on the target machine.
num_workers = 8
# NOTE(review): an *uncased* checkpoint is combined with do_lower_case=False; unless the
# reports are already lower-cased, capitalised words may not be found in the vocabulary
# — confirm this is intended.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)

# Training data: shuffled mini-batches.
training_set = MultiModalDataset(train_txt_gt, tokenizer, parent_dir)
train_params = {
    "batch_size": batch_size,
    "shuffle": True,
    "num_workers": num_workers,
    "pin_memory": True,
}
training_loader = DataLoader(training_set, **train_params)

# Validation / test data: one sample per batch, in file order.
valid_set = MultiModalDataset(val_txt_gt, tokenizer, parent_dir)
test_set = MultiModalDataset(test_txt_gt, tokenizer, parent_dir)
valid_params = {
    "batch_size": 1,
    "shuffle": False,
    "num_workers": 1,
    "pin_memory": True,
}
val_loader = DataLoader(valid_set, **valid_params)
test_loader = DataLoader(test_set, **valid_params)

Downloading vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [27]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
total_epochs = 15
eval_num = 1
lr = 1e-4
weight_decay = 1e-5

# NOTE(review): num_classes has to equal the number of label values per sample (the length
# of each `list` entry in the annotation tables) — confirm that 4 matches the CSV files.
# NOTE(review): MultiModalDataset applies Resize(256) to every image; confirm that the
# tensors it yields really have the img_size declared here.
model = Transchex(
    in_channels=3,
    img_size=(2048, 1536),
    num_classes=4,
    patch_size=(64, 64),
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
).to(device)  # the model has to live on the same device as the batches fed to it

# BCELoss expects probabilities, so callers apply a sigmoid to the logits first.
loss_bce = torch.nn.BCELoss()
optimizer = torch.optim.Adam(
    params=model.parameters(), lr=lr, weight_decay=weight_decay
)
# The schedule is expressed in epochs: 5 warm-up steps out of total_epochs steps.
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=total_epochs)

Downloading bert-base-uncased.tar.gz:   0%|          | 0.00/389M [00:00<?, ?B/s]

SSLError: [SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption failed or bad record mac (_ssl.c:2633)

In [23]:
def save_ckp(state, checkpoint_dir):
    """Serialize ``state`` with ``torch.save``.

    Args:
        state: object to persist, e.g. a dict with the epoch number and the
            model / optimizer state dicts.
        checkpoint_dir: destination of the checkpoint. Despite its name this is
            the *file* path (or a file-like object), not a folder.
    """
    # Create the destination folder on demand so that a missing directory cannot
    # cost the result of a finished epoch. File-like targets are passed through.
    if isinstance(checkpoint_dir, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(checkpoint_dir))
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state, checkpoint_dir)


def compute_AUCs(gt, pred, num_classes=14):
    """Return the per-class ROC AUC scores as a list.

    Column ``i`` of ``gt`` supplies the ground-truth labels and column ``i`` of
    ``pred`` the predicted scores of class ``i``. Only the first ``num_classes``
    columns are evaluated, each with ``roc_auc_score``.
    """
    return [
        roc_auc_score(gt[:, class_idx].tolist(), pred[:, class_idx].tolist())
        for class_idx in range(num_classes)
    ]


def train(epoch):
    """Run one training epoch over ``training_loader``.

    Works on the module-level ``model``, ``loss_bce`` and ``optimizer`` and prints
    the loss of every iteration.

    Args:
        epoch: epoch index; only used in the progress message.
    """
    model.train()
    # Send the batches to wherever the model lives instead of hard-coding CUDA, so
    # the loop neither mismatches devices nor fails on a machine without a GPU.
    device = next(model.parameters()).device
    for i, data in enumerate(training_loader, 0):
        input_ids = data["ids"].to(device)
        segment_ids = data["segment_ids"].to(device)
        img = data["images"].to(device)
        targets = data["targets"].to(device)
        logits_lang = model(
            input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids
        )
        # loss_bce works on probabilities, hence the explicit sigmoid.
        loss = loss_bce(torch.sigmoid(logits_lang), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() prints the plain number instead of the tensor representation.
        print(f"Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss.item()}")


def validation(testing_loader):
    """Evaluate the module-level ``model`` on ``testing_loader``.

    Args:
        testing_loader: iterable of batches with the keys ``ids``,
            ``segment_ids``, ``images`` and ``targets``.

    Returns:
        ``(mean_auc, mean_loss, auc)`` - the mean of the per-class AUCs, the mean
        of the per-batch losses and the list of per-class AUCs.
    """
    model.eval()
    # Send the batches to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device
    target_batches = []
    prob_batches = []
    val_loss = []
    with torch.no_grad():
        for data in testing_loader:
            input_ids = data["ids"].to(device)
            segment_ids = data["segment_ids"].to(device)
            img = data["images"].to(device)
            targets = data["targets"].to(device)
            logits_lang = model(
                input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids
            )
            prob = torch.sigmoid(logits_lang)
            val_loss.append(loss_bce(prob, targets).item())
            target_batches.append(targets.cpu().numpy())
            prob_batches.append(prob.cpu().numpy())
    # Stacking the batches keeps this independent of the batch size, and the class
    # count is read from the data rather than fixed to a constant.
    targets_in = np.concatenate(target_batches, axis=0)
    preds_cls = np.concatenate(prob_batches, axis=0)
    auc = compute_AUCs(targets_in, preds_cls, targets_in.shape[1])
    mean_auc = np.mean(auc)
    mean_loss = np.mean(val_loss)
    print(
        "Evaluation Statistics: Mean AUC : {}, Mean Loss : {}".format(
            mean_auc, mean_loss
        )
    )
    return mean_auc, mean_loss, auc


# Best validation AUC so far, plus the validation loss / AUC recorded after every epoch.
auc_val_best = 0.0
epoch_loss_values = []
metric_values = []
for epoch in range(total_epochs):
    train(epoch)
    auc_val, loss_val, _ = validation(val_loader)
    epoch_loss_values.append(loss_val)
    metric_values.append(auc_val)

    # A checkpoint is written only when the validation AUC improves.
    if auc_val > auc_val_best:
        checkpoint = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_ckp(checkpoint, logdir + "/transchex.pt")
        auc_val_best = auc_val
        status_template = (
            "Model Was Saved ! Current Best Validation AUC: {}    Current AUC: {}"
        )
    else:
        status_template = (
            "Model Was NOT Saved ! Current Best Validation AUC: {}    Current AUC: {}"
        )
    print(status_template.format(auc_val_best, auc_val))

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

NameError: name 'model' is not defined

In [16]:
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.poolmanager import PoolManager
import ssl

class MyAdapter(HTTPAdapter):
    """HTTPS transport adapter that lets the TLS version be negotiated."""

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        """Create the urllib3 ``PoolManager`` used by this adapter.

        Args:
            connections: number of connection pools to cache.
            maxsize: maximum number of connections kept per pool.
            block: whether a pool blocks when it has no free connection.
            **pool_kwargs: further options forwarded to ``PoolManager``.
        """
        # PROTOCOL_TLS_CLIENT negotiates the highest protocol version that both
        # sides support. ssl.PROTOCOL_TLSv1 must not be pinned here: it forces
        # TLS 1.0, which is deprecated in Python's ssl module.
        self.poolmanager = PoolManager(
            num_pools=connections,
            maxsize=maxsize,
            block=block,
            ssl_version=ssl.PROTOCOL_TLS_CLIENT,
            **pool_kwargs,
        )

In [17]:
import requests
# Session whose HTTPS requests are routed through MyAdapter.
# NOTE(review): no other visible cell uses `s`; requests issued by other libraries through
# their own sessions are presumably not affected by this mount — confirm it is still needed.
s = requests.Session()
s.mount('https://', MyAdapter())

In [18]:
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.poolmanager import PoolManager

class SSLAdapter(HTTPAdapter):
    """HTTPS transport adapter whose SSL/TLS protocol version is configurable.

    Args:
        ssl_version: value handed to the pool manager as ``ssl_version``.
        **kwargs: forwarded unchanged to ``HTTPAdapter.__init__``.
    """

    def __init__(self, ssl_version=None, **kwargs):
        # Stored before the parent initialiser runs: init_poolmanager reads it, and
        # the parent initialiser is expected to be what triggers that call.
        self.ssl_version = ssl_version
        super().__init__(**kwargs)

    def init_poolmanager(self, connections, maxsize, block=False):
        """Build the pool manager with the SSL version chosen at construction."""
        pool_options = {
            "num_pools": connections,
            "maxsize": maxsize,
            "block": block,
            "ssl_version": self.ssl_version,
        }
        self.poolmanager = PoolManager(**pool_options)