In [None]:
pip install -U bitsandbytes rouge_score nltk pycocoevalcap transformers

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pycocoevalcap
  Downloading pycocoevalcap-1.2-py3-none-any.whl.metadata (3.2 kB)
Collecting transformers
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda

In [None]:
# Mount Google Drive: the dataset images/CSVs and the saved checkpoints live under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model
from PIL import Image

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Seed torch's RNGs (CPU and CUDA) so LoRA / projection initialisation and DataLoader
# shuffling are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# 4-bit NF4 quantization (with double quantization) for the 7B decoder; matmuls run in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Constants
# NOTE(review): BATCH_SIZE and GRAD_ACCUM_STEPS are not referenced by the DataLoader cell;
# confirm which values are intended.
BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
MAX_LEN = 300

# DINOv2-base image encoder in fp16 on the GPU; only transformer blocks 10 and 11 are fine-tuned.
# NOTE(review): the trainable blocks are fp16 as well, so optimizer updates are applied directly to
# fp16 weights and small steps (e.g. lr=1e-5) may round away; consider fp32 master weights + autocast.
vision_encoder = AutoModel.from_pretrained(
    "facebook/dinov2-base"
).half().to("cuda")

for name, param in vision_encoder.named_parameters():
    param.requires_grad = ("encoder.layer.10" in name) or ("encoder.layer.11" in name)

# Reuse EOS as the pad token and pad on the right.
mistral_tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    padding_side="right",
    use_fast=True
)
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token

# Mistral is a built-in transformers architecture, so no remote code needs to be trusted/executed.
text_decoder = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
)
text_decoder.config.pad_token_id = mistral_tokenizer.pad_token_id

# LoRA adapters on the attention projections of decoder layers 30 and 31.
# get_peft_model() marks only the adapter weights as trainable (bias="none") and freezes every other
# parameter, so a manual requires_grad pass over the base decoder would be overwritten anyway.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj"
    ],
    layers_to_transform=[30, 31],
    bias="none",
    task_type="CAUSAL_LM"
)

text_decoder = get_peft_model(text_decoder, lora_config)

# The LoRA weights are intentionally NOT cast to fp16: PEFT casts activations to the adapter dtype and
# back around each adapter, and keeping the adapters in float32 avoids fp16 rounding of optimizer
# updates (and the "Attempting to unscale FP16 gradients" error GradScaler raises for fp16 params).

class ProjectionLayer(nn.Module):
    """Project a vision feature vector into the text decoder's embedding space (Linear -> GELU).

    The parameters stay in the dtype the module was created/cast to (float32 by default), so optimizer
    updates are not rounded to fp16. ``forward`` casts its input to the weight dtype and casts the
    result back to the input dtype, so fp16 features from the fp16 vision encoder still produce fp16
    embeddings that can be concatenated with the decoder's fp16 token embeddings.

    Args:
        in_dim: size of the incoming feature vector (768 for the DINOv2-base CLS token).
        out_dim: size of the decoder embedding (4096 for Mistral-7B).
    """

    def __init__(self, in_dim=768, out_dim=4096):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        projected = self.gelu(self.proj(x.to(self.proj.weight.dtype)))
        return projected.to(x.dtype)


# Kept in float32 (no .half()): this layer is fully trainable.
projection = ProjectionLayer().to("cuda" if torch.cuda.is_available() else "cpu")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Training Epoch 1/10:   0%|          | 0/2678 [38:32<?, ?it/s, Train Loss=0.1379]


In [None]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import AutoTokenizer

MAX_LEN = 300  # token budget per caption, including BOS and the appended EOS


class CXRMultiViewDataset(Dataset):
    """Paired frontal/lateral chest X-rays with their report caption.

    Each CSV row provides the two image filenames (columns "Image 1" and "Image 2", relative to
    ``root``) and a "caption". Items are dicts with processor-normalised pixel tensors for both views,
    ``input_ids`` / ``attention_mask`` of length ``max_len``, LM ``labels`` (padding set to -100), and
    the raw caption under "references".

    Args:
        root: directory containing the image files.
        caption_file: CSV path with "Image 1", "Image 2" and "caption" columns.
        processor: image processor called as ``processor(images=..., return_tensors="pt")``.
        tokenizer_name: tokenizer to load when ``tokenizer`` is not given. Must match the text
            decoder; the default is the Mistral-7B tokenizer used by the decoder in this notebook.
        max_len: fixed token length of every encoded caption (>= 2).
        tokenizer: optional already-loaded tokenizer to reuse instead of loading ``tokenizer_name``.
    """

    def __init__(self, root, caption_file, processor, tokenizer_name="mistralai/Mistral-7B-v0.1",
                 max_len=MAX_LEN, tokenizer=None):
        if max_len < 2:
            raise ValueError("max_len must leave room for at least one caption token and EOS")
        self.root = root
        self.data = pd.read_csv(caption_file)
        self.processor = processor
        self.image_size = 224
        self.max_len = max_len
        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

    def __len__(self):
        return len(self.data)

    def _load_view(self, filename):
        """Load one radiograph as a (3, H, W) tensor of processor-normalised pixels."""
        path = os.path.join(self.root, filename)
        # Context manager closes the file handle once the pixels are decoded.
        with Image.open(path) as img:
            gray = img.convert('L').resize((self.image_size, self.image_size))
        # Scale to [0, 1] and replicate the grey channel to 3 channels (channel-first).
        arr = np.asarray(gray, dtype=np.float32) / 255.0
        rgb = torch.from_numpy(np.repeat(arr[None, :, :], 3, axis=0))
        return self.processor(images=rgb, return_tensors="pt")["pixel_values"].squeeze(0)

    def _encode_caption(self, caption):
        """Tokenize ``caption``, append EOS, right-pad to ``max_len`` and build masked LM labels."""
        ids = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_len - 1,  # reserve one slot for the explicit EOS
            add_special_tokens=True
        )["input_ids"]
        # Explicit EOS so the decoder learns where a report ends (pad == eos, so it must be attended).
        ids = list(ids) + [self.tokenizer.eos_token_id]
        n_real = len(ids)
        n_pad = self.max_len - n_real
        input_ids = torch.tensor(ids + [self.tokenizer.pad_token_id] * n_pad, dtype=torch.long)
        attention_mask = torch.tensor([1] * n_real + [0] * n_pad, dtype=torch.long)
        labels = input_ids.clone()
        labels[n_real:] = -100  # padding does not contribute to the loss
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        caption = row["caption"]
        encoded = self._encode_caption(caption)

        return {
            "pixel_values_front": self._load_view(row["Image 1"]),
            "pixel_values_lat":   self._load_view(row["Image 2"]),
            "input_ids":          encoded["input_ids"],
            "attention_mask":     encoded["attention_mask"],
            "labels":             encoded["labels"],
            # Raw reference text for metric evaluation
            "references":         caption
        }


def multi_view_collate_fn(batch):
    """Stack a list of dataset items into one batch dict.

    Tensor fields are stacked along a new leading batch dimension. "labels" is stacked too when the
    items provide it. "references" stays a plain list of caption strings.
    """
    tensor_keys = ["pixel_values_front", "pixel_values_lat", "input_ids", "attention_mask"]
    if batch and "labels" in batch[0]:
        tensor_keys.append("labels")
    collated = {key: torch.stack([item[key] for item in batch]) for key in tensor_keys}
    collated["references"] = [item["references"] for item in batch]
    return collated


In [None]:
# Build the train/valid/test splits. All three share one DINOv2 processor (loaded once) and encode
# captions with the same tokenizer as the Mistral text decoder, so the token ids fed to the decoder
# and used as labels refer to the decoder's own vocabulary.
DATA_ROOT = '/content/drive/MyDrive/Small_human_extracted'
DECODER_TOKENIZER_NAME = "mistralai/Mistral-7B-v0.1"
LOADER_BATCH_SIZE = 8

# do_rescale=False: the dataset already divides pixel values by 255 before calling the processor.
dinov2_processor = AutoProcessor.from_pretrained("facebook/dinov2-base", do_rescale=False)


def make_split_dataset(split):
    """Return the CXRMultiViewDataset for `split` ('Train', 'Valid' or 'Test') under DATA_ROOT."""
    return CXRMultiViewDataset(
        f'{DATA_ROOT}/Images/{split}/',
        f'{DATA_ROOT}/{split}_captions.csv',
        processor=dinov2_processor,
        tokenizer_name=DECODER_TOKENIZER_NAME,
    )


def make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with the shared batch size and multi-view collate function."""
    return DataLoader(
        dataset,
        batch_size=LOADER_BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=multi_view_collate_fn,
        pin_memory=True
    )


train_dataset = make_split_dataset('Train')
valid_dataset = make_split_dataset('Valid')
test_dataset = make_split_dataset('Test')

train_loader = make_loader(train_dataset, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps validation runs comparable.
valid_loader = make_loader(valid_dataset, shuffle=False)
test_loader = make_loader(test_dataset, shuffle=False)


In [None]:
import torch
import torch.nn as nn

class CXRReportGenerator(nn.Module):
    """Two-view chest X-ray report generator.

    The CLS tokens of the frontal and lateral views are averaged and projected into a single
    prefix embedding. The prefix is prepended to the caption's token embeddings for teacher-forced
    training, or used alone as the prompt for generation.
    NOTE(review): only the CLS token reaches the decoder; patch tokens are discarded.
    """

    def __init__(self, vision_encoder, text_decoder, projection):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_decoder = text_decoder
        self.projection = projection

    def _encode_views(self, pixel_values_front, pixel_values_lat):
        """Return the mean of the frontal and lateral CLS embeddings, shape (B, hidden)."""
        batch_size = pixel_values_front.size(0)
        # One encoder pass over both views (stacked along the batch dim) instead of two.
        both_views = torch.cat([pixel_values_front, pixel_values_lat], dim=0)
        cls_tokens = self.vision_encoder(both_views).last_hidden_state[:, 0, :]
        return 0.5 * (cls_tokens[:batch_size] + cls_tokens[batch_size:])

    def _generate(self, vision_prefix, max_new_tokens, generation_kwargs):
        """Generate token ids conditioned only on the (B, 1, hidden) vision prefix."""
        settings = {
            "temperature": 0.7,
            "top_k": 50,
            "do_sample": True,
            "pad_token_id": self.text_decoder.config.pad_token_id,
        }
        if generation_kwargs:
            settings.update(generation_kwargs)
        prefix_mask = torch.ones(vision_prefix.shape[:2], dtype=torch.long, device=vision_prefix.device)
        return self.text_decoder.generate(
            inputs_embeds=vision_prefix,
            attention_mask=prefix_mask,
            max_new_tokens=max_new_tokens,
            **settings
        )

    def forward(
        self,
        pixel_values_front=None,
        pixel_values_lat=None,
        input_ids=None,
        attention_mask=None,
        labels=None,
        max_new_tokens=256,
        generation_kwargs=None
    ):
        """Return the LM loss (when ``input_ids`` is given) or generated token ids (otherwise).

        Args:
            pixel_values_front / pixel_values_lat: pixel tensors for the two views (required).
            input_ids / attention_mask: caption tokens for teacher forcing; a missing mask means
                "attend to every token".
            labels: LM targets aligned with ``input_ids`` (-100 = ignored); the prefix position is
                always ignored.
            max_new_tokens: generation length limit.
            generation_kwargs: optional overrides for ``generate`` (e.g. ``{"do_sample": False}``);
                defaults are temperature=0.7, top_k=50, do_sample=True.

        Raises:
            ValueError: if either view is missing.
        """
        if pixel_values_front is None or pixel_values_lat is None:
            raise ValueError("pixel_values_front and pixel_values_lat are both required")

        device = pixel_values_front.device
        batch_size = pixel_values_front.size(0)

        combined_cls = self._encode_views(pixel_values_front, pixel_values_lat.to(device))
        projected_vision = self.projection(combined_cls)

        embedding_layer = self.text_decoder.get_input_embeddings()
        # Match the decoder's embedding dtype so the prefix can be concatenated with token embeddings.
        vision_prefix = projected_vision.unsqueeze(1).to(embedding_layer.weight.dtype)

        if input_ids is None:
            return self._generate(vision_prefix, max_new_tokens, generation_kwargs)

        input_ids = input_ids.to(device)
        attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask.to(device)

        text_embeds = embedding_layer(input_ids)
        inputs_embeds = torch.cat([vision_prefix, text_embeds], dim=1)

        # Prefix is always attended; keep the mask's integer dtype.
        prefix_mask = torch.ones(batch_size, 1, dtype=attention_mask.dtype, device=device)
        combined_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        shifted_labels = None
        if labels is not None:
            labels = labels.to(device)
            # No target for the vision prefix position; the decoder shifts labels internally.
            ignore = torch.full((batch_size, 1), -100, dtype=labels.dtype, device=device)
            shifted_labels = torch.cat([ignore, labels], dim=1)

        outputs = self.text_decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=combined_attention_mask,
            labels=shifted_labels
        )
        return outputs.loss


In [None]:
import bitsandbytes as bnb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CXRReportGenerator(
    vision_encoder=vision_encoder,
    text_decoder=text_decoder,
    projection=projection
)
# Move only the full-precision components. The 4-bit decoder was already placed by
# device_map="auto"; model.to(device) would re-place all of its weights onto one device.
model.vision_encoder.to(device)
model.projection.to(device)

# Parameter groups: projection + LoRA adapters at 1e-4, the two unfrozen DINOv2 blocks at 1e-5.
trainable_proj    = [p for p in model.projection.parameters() if p.requires_grad]
trainable_decoder = [p for p in model.text_decoder.parameters() if p.requires_grad]
trainable_encoder = [p for p in model.vision_encoder.parameters() if p.requires_grad]

optimizer = bnb.optim.Adam8bit(
    [
        {'params': trainable_proj,    'lr': 1e-4},
        {'params': trainable_decoder, 'lr': 1e-4, 'weight_decay': 0.01},
        {'params': trainable_encoder, 'lr': 1e-5}
    ],
    betas=(0.9, 0.999),
    optim_bits=8
)

In [None]:
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from pycocoevalcap.cider.cider import Cider
import nltk
import numpy as np
import gc
from tqdm import tqdm
nltk.download('punkt')
# NLTK >= 3.9 looks up the 'punkt_tab' resource in word_tokenize; older versions ignore it.
nltk.download('punkt_tab')


def calculate_metrics(predictions, references):
    """Score generated reports against references.

    Texts are lower-cased and word-tokenized once. The same tokens feed BLEU and CIDEr (CIDEr splits
    on whitespace, so it receives the space-joined tokens). ROUGE-L uses rouge_score's own tokenizer
    with stemming.

    Args:
        predictions: generated report strings.
        references: one reference string per prediction, in the same order.

    Returns:
        (bleu4, rouge_l, cider_score): corpus BLEU-4 (method4 smoothing), mean ROUGE-L F1, and
        corpus CIDEr.

    Raises:
        ValueError: if the two sequences differ in length or are empty.
    """
    predictions = list(predictions)
    references = list(references)
    if len(predictions) != len(references):
        raise ValueError(
            f"got {len(predictions)} predictions but {len(references)} references"
        )
    if not predictions:
        raise ValueError("calculate_metrics needs at least one prediction/reference pair")

    ref_tokens = [nltk.word_tokenize(ref.lower()) for ref in references]
    hyp_tokens = [nltk.word_tokenize(pred.lower()) for pred in predictions]

    # BLEU-4
    bleu4 = corpus_bleu(
        [[tokens] for tokens in ref_tokens], hyp_tokens,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=SmoothingFunction().method4
    )

    # ROUGE-L
    rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rouge_l = float(np.mean([rouge.score(ref, hyp)['rougeL'].fmeasure
                             for ref, hyp in zip(references, predictions)]))

    # CIDEr
    cider = Cider()
    refs_cider = {i: [" ".join(tokens)] for i, tokens in enumerate(ref_tokens)}
    hyps_cider = {i: [" ".join(tokens)] for i, tokens in enumerate(hyp_tokens)}
    cider_score, _ = cider.compute_score(refs_cider, hyps_cider)

    return bleu4, rouge_l, float(cider_score)

def _labels_for_batch(batch, input_ids, attention_mask):
    """Return LM targets for a batch: padding ignored (-100), caption tokens (+ one EOS) supervised.

    Uses ``batch["labels"]`` when the dataset provides it. Otherwise falls back to ``input_ids`` with
    padding masked out, except the first padding slot. Pad is set to EOS in this notebook and padding
    is on the right, so that slot is the only end-of-report signal (presumably the tokenizer does not
    append EOS itself; confirm).
    """
    if "labels" in batch:
        return batch["labels"].to(input_ids.device)
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    lengths = attention_mask.sum(dim=1)
    rows = torch.nonzero(lengths < input_ids.size(1), as_tuple=True)[0]
    labels[rows, lengths[rows]] = input_ids[rows, lengths[rows]]
    return labels


def _train_one_epoch(model, train_loader, optimizer, device, epoch, num_epochs,
                     grad_accum_steps, max_grad_norm, log_every):
    """Run one optimisation pass over ``train_loader``; return the mean of the finite batch losses."""
    model.train()
    params = [p for group in optimizer.param_groups for p in group["params"]]
    optimizer.zero_grad(set_to_none=True)
    total_loss, n_finite, n_skipped = 0.0, 0, 0
    n_steps = len(train_loader)

    train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}")
    for step, batch in enumerate(train_pbar):
        pixel_values_front = batch["pixel_values_front"].to(device)
        pixel_values_lat   = batch["pixel_values_lat"].to(device)
        input_ids          = batch["input_ids"].to(device)
        attention_mask     = batch["attention_mask"].to(device)
        labels             = _labels_for_batch(batch, input_ids, attention_mask)

        loss = model(
            pixel_values_front=pixel_values_front,
            pixel_values_lat=pixel_values_lat,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # A NaN/inf loss would poison the weights (and later crash sampling); drop the whole
        # accumulation window instead of stepping on it.
        if not torch.isfinite(loss):
            n_skipped += 1
            optimizer.zero_grad(set_to_none=True)
            continue

        (loss / grad_accum_steps).backward()
        if (step + 1) % grad_accum_steps == 0 or step + 1 == n_steps:
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        loss_value = loss.item()
        total_loss += loss_value
        n_finite += 1

        if (step + 1) % log_every == 0:
            print(f"[Epoch {epoch+1}/{num_epochs} - Step {step+1}/{n_steps}] "
                  f"Train Loss: {loss_value:.4f}")

        train_pbar.set_postfix({
            "Train Loss": f"{loss_value:.4f}"
        })

    if n_skipped:
        print(f"Epoch {epoch+1}/{num_epochs} - skipped {n_skipped} batches with non-finite loss")
    return total_loss / n_finite if n_finite else float("nan")


def _validate(model, val_loader, tokenizer, device, epoch, num_epochs):
    """Return (mean validation loss, decoded predictions, reference strings) for ``val_loader``."""
    model.eval()
    predictions, references = [], []
    total_loss, n_batches = 0.0, 0

    val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
    with torch.no_grad():
        for batch in val_pbar:
            pixel_values_front = batch["pixel_values_front"].to(device)
            pixel_values_lat   = batch["pixel_values_lat"].to(device)
            input_ids          = batch["input_ids"].to(device)
            attention_mask     = batch["attention_mask"].to(device)
            labels             = _labels_for_batch(batch, input_ids, attention_mask)

            val_loss = model(
                pixel_values_front=pixel_values_front,
                pixel_values_lat=pixel_values_lat,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            ).item()
            total_loss += val_loss
            n_batches += 1

            generated_ids = model(
                pixel_values_front=pixel_values_front,
                pixel_values_lat=pixel_values_lat
            )
            predictions.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
            references.extend(batch["references"])  # list of strings

            val_pbar.set_postfix({
                "Val Loss": f"{val_loss:.4f}"
            })

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    return avg_loss, predictions, references


def train_and_validate(
    model,
    train_loader,
    val_loader,
    optimizer,
    tokenizer,
    num_epochs=5,
    device='cuda',
    save_path="best_model.pt",
    grad_accum_steps=1,
    max_grad_norm=1.0,
    log_every=50
):
    """Train ``model`` and keep the checkpoint with the best validation BLEU-4.

    Args:
        model: a CXRReportGenerator-style module returning a loss (with input_ids) or generated ids.
        train_loader / val_loader: loaders yielding multi-view batches.
        optimizer: optimizer over the model's trainable parameters.
        tokenizer: tokenizer used to decode generated ids. It must be the one that encoded the
            training targets.
        num_epochs: number of epochs.
        device: device the batches are moved to.
        save_path: where ``model.state_dict()`` is written when BLEU-4 improves.
        grad_accum_steps: micro-batches per optimizer step (1 = step every batch).
        max_grad_norm: gradient-norm clipping threshold; None disables clipping.
        log_every: print the running train loss every this many steps.

    Returns:
        A list with one dict per epoch holding the losses and metrics.
    """
    if grad_accum_steps < 1:
        raise ValueError("grad_accum_steps must be >= 1")

    # -inf so that at least the first epoch is checkpointed even if BLEU-4 is 0.
    best_bleu4 = float("-inf")
    history = []

    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, device, epoch, num_epochs,
            grad_accum_steps, max_grad_norm, log_every
        )
        print(f"Epoch {epoch+1}/{num_epochs} - Average Train Loss: {avg_train_loss:.4f}")

        avg_val_loss, val_predictions, val_references = _validate(
            model, val_loader, tokenizer, device, epoch, num_epochs
        )

        # Calculate metrics on validation set
        bleu4, rouge_l, cider_score = calculate_metrics(val_predictions, val_references)

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | "
              f"BLEU-4: {bleu4:.4f} | "
              f"ROUGE-L: {rouge_l:.4f} | "
              f"CIDEr: {cider_score:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "bleu4": bleu4,
            "rouge_l": rouge_l,
            "cider": cider_score,
        })

        if bleu4 > best_bleu4:
            best_bleu4 = bleu4
            print(f"New best BLEU-4 ({best_bleu4:.4f}) - saving model...")
            torch.save(model.state_dict(), save_path)

    return history



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# Stage 1 training. Generated ids are decoded with the dataset's own tokenizer, i.e. the tokenizer
# that produced the training targets, so predictions and references share a vocabulary.
train_and_validate(
    model=model,
    train_loader=train_loader,
    val_loader=valid_loader,
    optimizer=optimizer,
    tokenizer=train_dataset.tokenizer,
    num_epochs=10,
    device=device,
    save_path="/content/drive/MyDrive/Small_human_extracted/best_rg_model.pt"
)


  with amp.autocast():

Training Epoch 1/10:   0%|          | 0/335 [00:00<?, ?it/s, Loss=4.1163, LR=0.00e+00][A
Training Epoch 1/10:   0%|          | 1/335 [00:00<04:17,  1.30it/s, Loss=4.1163, LR=0.00e+00][A
Training Epoch 1/10:   0%|          | 1/335 [00:01<04:17,  1.30it/s, Loss=5.0413, LR=0.00e+00][A
Training Epoch 1/10:   1%|          | 2/335 [00:01<04:15,  1.30it/s, Loss=5.0413, LR=0.00e+00][A
Training Epoch 1/10:   1%|          | 2/335 [00:02<04:15,  1.30it/s, Loss=4.4541, LR=0.00e+00][A
Training Epoch 1/10:   1%|          | 3/335 [00:02<04:14,  1.30it/s, Loss=4.4541, LR=0.00e+00][A
Training Epoch 1/10:   1%|          | 3/335 [00:03<04:14,  1.30it/s, Loss=5.1722, LR=0.00e+00][A
Training Epoch 1/10:   1%|          | 4/335 [00:03<04:18,  1.28it/s, Loss=5.1722, LR=0.00e+00][A
Training Epoch 1/10:   1%|          | 4/335 [00:03<04:18,  1.28it/s, Loss=4.8952, LR=0.00e+00][A
Training Epoch 1/10:   1%|▏         | 5/335 [00:03<04:12,  1.31it/s, Loss=4.8952, LR=0.00e+00][A
Tra

ValueError: Attempting to unscale FP16 gradients.

In [None]:
# Restore the best checkpoint written by train_and_validate() (a full model.state_dict()).
# That state dict carries bitsandbytes 4-bit quantization-state entries (absmax, quant_map, ...)
# which load_state_dict() reports as unexpected keys, so strict loading is disabled. The check
# below still fails loudly if any *trainable* parameter is absent from the checkpoint.
CHECKPOINT_PATH = "/content/drive/MyDrive/Small_human_extracted/best_rg_model.pt"

# weights_only=False: the file is produced by this notebook; it may contain bitsandbytes tensor
# subclasses that the weights-only unpickler rejects. Never load untrusted files this way.
ckpt = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)

load_result = model.load_state_dict(ckpt, strict=False)
trainable_names = {name for name, param in model.named_parameters() if param.requires_grad}
missing_trainable = trainable_names.intersection(load_result.missing_keys)
if missing_trainable:
    raise RuntimeError(
        f"Checkpoint is missing {len(missing_trainable)} trainable parameters, "
        f"e.g. {sorted(missing_trainable)[:3]}"
    )
load_result

  ckpt = torch.load("/content/drive/MyDrive/Small_human_extracted/best_cxr_model.pt", map_location=device)


_IncompatibleKeys(missing_keys=[], unexpected_keys=['text_decoder.base_model.model.model.layers.0.self_attn.q_proj.weight.absmax', 'text_decoder.base_model.model.model.layers.0.self_attn.q_proj.weight.quant_map', 'text_decoder.base_model.model.model.layers.0.self_attn.q_proj.weight.nested_absmax', 'text_decoder.base_model.model.model.layers.0.self_attn.q_proj.weight.nested_quant_map', 'text_decoder.base_model.model.model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4', 'text_decoder.base_model.model.model.layers.0.self_attn.k_proj.weight.absmax', 'text_decoder.base_model.model.model.layers.0.self_attn.k_proj.weight.quant_map', 'text_decoder.base_model.model.model.layers.0.self_attn.k_proj.weight.nested_absmax', 'text_decoder.base_model.model.model.layers.0.self_attn.k_proj.weight.nested_quant_map', 'text_decoder.base_model.model.model.layers.0.self_attn.k_proj.weight.quant_state.bitsandbytes__nf4', 'text_decoder.base_model.model.model.layers.0.self_attn.v_proj.weight.ab

In [None]:
import gc
from tqdm import tqdm

# Generate one report per test study. Decode with the tokenizer that encoded the training targets.
test_predictions = []
test_references = []
tokenizer = test_dataset.tokenizer
model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Inference on Test"):
        generated_ids = model(
            pixel_values_front=batch["pixel_values_front"].to(device),
            pixel_values_lat=batch["pixel_values_lat"].to(device)
        )
        test_predictions.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
        # References are optional (e.g. an unlabeled test split).
        test_references.extend(batch.get("references", []))


Inference on Test: 100%|██████████| 64/64 [31:00<00:00, 29.07s/it]


In [None]:
# Only score when every prediction has a reference; calculate_metrics cannot score empty inputs.
if test_references and len(test_references) == len(test_predictions):
    bleu4, rouge_l, cider_score = calculate_metrics(test_predictions, test_references)
    print(f"Test BLEU-4: {bleu4:.4f} | Test ROUGE-L: {rouge_l:.4f} | Test CIDEr: {cider_score:.4f}")
elif not test_references:
    print("No references in test set, skipping metrics computation.")
else:
    print(f"Reference/prediction count mismatch ({len(test_references)} vs "
          f"{len(test_predictions)}), skipping metrics computation.")

Test BLEU-4: 0.0394 | Test ROUGE-L: 0.1410 | Test CIDEr: 0.0000


In [None]:
# Show the first few generated reports next to their references (bounded by what exists).
N_SAMPLES_TO_SHOW = 2
for i in range(min(N_SAMPLES_TO_SHOW, len(test_predictions))):
    print(f"--- Test Sample {i} ---")
    print(f"Generated: {test_predictions[i]}")
    if i < len(test_references):
        print(f"Reference: {test_references[i]}")
    print()


--- Test Sample 0 ---
Generated: the heart is normal in size. the mediastinal contours are within normal limits. there is no pleural effusion or pneumothorax. there is no focal airspace consolidation. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute cardiopulmonary disease. there is no evidence of acute

In [None]:
# Stage 2: a fresh optimizer (new Adam moments) with a 10x higher LR for the projection and LoRA.
# The unfrozen DINOv2 blocks are pretrained weights; the previous version also raised their LR to
# 1e-3 (100x stage 1), and that run ended in a CUDA device-side assert during sampling. NaN/inf
# probabilities in torch.multinomial are a common cause of that assert. Keep the encoder at the
# stage-1 rate.
STAGE2_LR = 1e-3
STAGE2_ENCODER_LR = 1e-5

optimizer = bnb.optim.Adam8bit(
    [
        {'params': trainable_proj,    'lr': STAGE2_LR},
        {'params': trainable_decoder, 'lr': STAGE2_LR, 'weight_decay': 0.01},
        {'params': trainable_encoder, 'lr': STAGE2_ENCODER_LR}
    ],
    betas=(0.9, 0.999),
    optim_bits=8
)

In [None]:
# Stage 2 training. `medalpaca_tokenizer` was never defined in this notebook (NameError on a fresh
# kernel); decode with the dataset's tokenizer, which produced the training targets.
train_and_validate(
    model=model,
    train_loader=train_loader,
    val_loader=valid_loader,
    optimizer=optimizer,
    tokenizer=train_dataset.tokenizer,
    num_epochs=50,
    device=device,
    save_path="/content/drive/MyDrive/Small_human_extracted/best_cxr_model_2.pt"
)

Training Epoch 1/50:   0%|          | 0/224 [01:12<?, ?it/s, Train Loss=0.4095]

[Epoch 1/50 - Step 50/224] Train Loss: 0.4095


Training Epoch 1/50:   0%|          | 0/224 [02:23<?, ?it/s, Train Loss=0.3583]

[Epoch 1/50 - Step 100/224] Train Loss: 0.3583


Training Epoch 1/50:   0%|          | 0/224 [03:35<?, ?it/s, Train Loss=0.5014]

[Epoch 1/50 - Step 150/224] Train Loss: 0.5014


Training Epoch 1/50:   0%|          | 0/224 [04:47<?, ?it/s, Train Loss=0.4774]

[Epoch 1/50 - Step 200/224] Train Loss: 0.4774


Training Epoch 1/50:   0%|          | 0/224 [05:21<?, ?it/s, Train Loss=0.6585]

Epoch 1/50 - Average Train Loss: 0.6533




Validation Epoch 1/50:   0%|          | 0/32 [00:30<?, ?it/s, Val Loss=0.2177][A
Validation Epoch 1/50:   0%|          | 0/32 [00:58<?, ?it/s, Val Loss=0.2888][A
Validation Epoch 1/50:   0%|          | 0/32 [01:26<?, ?it/s, Val Loss=0.3020][A
Validation Epoch 1/50:   0%|          | 0/32 [01:54<?, ?it/s, Val Loss=0.3098][A
Validation Epoch 1/50:   0%|          | 0/32 [02:25<?, ?it/s, Val Loss=0.3222][A
Validation Epoch 1/50:   0%|          | 0/32 [02:54<?, ?it/s, Val Loss=0.3114][A
Validation Epoch 1/50:   0%|          | 0/32 [03:20<?, ?it/s, Val Loss=0.3934][A
Validation Epoch 1/50:   0%|          | 0/32 [03:49<?, ?it/s, Val Loss=0.2823][A
Validation Epoch 1/50:   0%|          | 0/32 [04:18<?, ?it/s, Val Loss=0.3626][A

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
