##### Just for Exploration

In [1]:
from transformers import BioGptModel,ViTImageProcessor, BioGptConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

  from .autonotebook import tqdm as notebook_tqdm


In [101]:
config_encoder = ViTConfig()
config_decoder= BioGptConfig()
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
model = VisionEncoderDecoderModel(config=config)

In [54]:
config_encoder = model.config.encoder
config_decoder = model.config.decoder
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True

In [None]:
model.save_pretrained("my-model")
encoder_decoder_config = VisionEncoderDecoderConfig.from_pretrained("my-model")
model = VisionEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)

### Encoder + Decoder

In [2]:
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel
from datasets import load_dataset
from transformers import BioGptTokenizer, BioGptForCausalLM


In [3]:
def initialise(image_processor_path,tokenizer_path,processor,token):
    image_processor = processor.from_pretrained(image_processor_path)
    tokenizer = token.from_pretrained(tokenizer_path)
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(image_processor_path, tokenizer_path)
    return image_processor,tokenizer,model

In [4]:
image_processor,tokenizer,model = initialise("google/vit-base-patch16-224-in21k","google-bert/bert-base-uncased",ViTImageProcessor,BertTokenizer)

Some weights of BertLMHeadModel were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bi

In [5]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

## Actual Implementation
###### One image example

In [6]:
from torchvision import transforms
from PIL import Image

image_path = "/Users/archita/borealis/borealis_rc/RadioCareBorealisAI/data/files/p10/p10000032/s50414267/02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.jpg"

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load and preprocess the image
image = Image.open(image_path).convert("RGB")
image_tensor = preprocess(image)


pixel_values = image_processor(image_tensor, return_tensors="pt").pixel_values

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


In [11]:
pixel_values

tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]]])

One Text Example

In [8]:
labels = tokenizer(
    """                                 FINAL REPORT
 EXAMINATION:  CHEST (PA AND LAT)
 
 INDICATION:  ___F with new onset ascites  // eval for infection
 
 TECHNIQUE:  Chest PA and lateral
 
 COMPARISON:  None.
 
 FINDINGS: 
 
 There is no focal consolidation, pleural effusion or pneumothorax.  Bilateral
 nodular opacities that most likely represent nipple shadows. The
 cardiomediastinal silhouette is normal.  Clips project over the left lung,
 potentially within the breast. The imaged upper abdomen is unremarkable.
 Chronic deformity of the posterior left sixth and seventh ribs are noted.
 
 IMPRESSION: 
 
 No acute cardiopulmonary process.
""",
    return_tensors="pt").input_ids


In [9]:
print(len(labels))

1


In [10]:
labels

tensor([[  101,  2345,  3189,  7749,  1024,  3108,  1006,  6643,  1998,  2474,
          2102,  1007, 12407,  1024,  1035,  1035,  1035,  1042,  2007,  2047,
         14447,  2004, 17847,  2015,  1013,  1013,  9345,  2140,  2005,  8985,
          6028,  1024,  3108,  6643,  1998, 11457,  7831,  1024,  3904,  1012,
          9556,  1024,  2045,  2003,  2053, 15918, 17439,  1010, 20228, 11236,
          2389,  1041,  4246, 14499,  2030,  1052,  2638,  2819, 29288,  2527,
          2595,  1012, 17758,  7293,  7934,  6728,  6305,  6447,  2008,  2087,
          3497,  5050, 14298,  6281,  1012,  1996,  4003, 18994,  2098,  7951,
         13770,  2140, 21776,  2003,  3671,  1012, 15281,  2622,  2058,  1996,
          2187, 11192,  1010,  9280,  2306,  1996,  7388,  1012,  1996,  3746,
          2094,  3356, 13878,  2003,  4895, 28578, 17007,  3085,  1012, 11888,
         13366,  2953, 16383,  1997,  1996, 15219,  2187,  4369,  1998,  5066,
         10335,  2024,  3264,  1012,  8605,  1024,  

In [12]:
# the forward function automatically creates the correct decoder_input_ids
loss = model(pixel_values=pixel_values, labels=labels).loss

In [13]:
loss

tensor(11.2009, grad_fn=<NllLossBackward0>)

### Whole Data

In [14]:
import sys
sys.path.append('/Users/archita/borealis/borealis_rc/RadioCareBorealisAI/')
from data_modules.mimic_cxr import MimicIVCXR

In [15]:
class config:
    seed = 23
    train_batch_size = 10
    valid_batch_size = 10
    test_batch_size = 10
    num_labels = 2
    num_epochs = 2


def initialise(image_processor_path,tokenizer_path,processor,token):
    image_processor = processor.from_pretrained(image_processor_path)
    tokenizer = token.from_pretrained(tokenizer_path)
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(image_processor_path, tokenizer_path)
    return image_processor,tokenizer,model

In [16]:
import torch
from torchvision import transforms


preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


# Set your device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_root = "/Users/archita/borealis/borealis_rc/RadioCareBorealisAI/graph_report.csv"

dataset = MimicIVCXR(data_root,tokenizer=None,max_length=3000,transform=preprocess)

print(dataset.__len__())


33


In [17]:
from sklearn.model_selection import train_test_split

train_data, val_test_data = train_test_split(dataset, test_size=0.2)
val_data, test_data = train_test_split(val_test_data, test_size=0.5)

In [18]:
from torch.utils.data import DataLoader

# # Load your data and create dataloaders
train_dataloader = DataLoader(train_data, batch_size=config.train_batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=config.valid_batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=config.test_batch_size, shuffle=False)

In [19]:
image_processor, tokenizer, model = initialise("google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased", ViTImageProcessor, BertTokenizer)

Some weights of BertLMHeadModel were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bi

In [20]:
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

In [21]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [28]:
# clean up the text
# data loader - tokenization
# script: training + validation
# decoders - Clinical data

# def train_model(model, optimizer, train_loader, device):
from torch.nn.utils.rnn import pad_sequence

train_losses = []  
max_length = 0
for epoch in range(config.num_epochs):
    model.train()
    total_loss = 0.0
    for batch in train_dataloader:
        
        input_ids_list = []
        attention_mask_list = []

        pixel_values, inputs, output_label = batch
        max_batch_length = max(len(tokenizer(i)['input_ids']) for i in inputs)
        max_length = max(max_length, max_batch_length)


        for i in inputs:
            # Tokenize the input text
            tokenized_text = tokenizer(i, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
            
            # Get input ids
            input_ids = tokenized_text['input_ids'].squeeze(0)
            attention_mask = tokenized_text['attention_mask']


            # Append tokenized input to lists
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        padded_input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=tokenizer.pad_token_id)

        # Compute the loss
        outputs = model(pixel_values=pixel_values, labels=padded_input_ids)

        loss = outputs.loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

Epoch 1, Average Loss: 7.6831
Epoch 2, Average Loss: 7.3433
