# INSTALL DEPENDENCIES

In [None]:
!pip install torchinfo
!pip install albumentations

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
import os
from torch.utils.data import Dataset, DataLoader
import cv2
import pandas as pd
from torchvision import transforms
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_curve, auc, confusion_matrix, jaccard_score
import gc
#from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary
import seaborn as sns
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random
import numpy as np
import pickle
import torch.nn.functional as F
from collections import OrderedDict
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.amp.syncfree.adamw as adamw

'import torch_xla\nimport torch_xla.core.xla_model as xm\nimport torch_xla.amp.syncfree.adamw as adamw'

# DATALOADER

This section defines the dataset and data-loading helpers used to train the backbone.


In [None]:
def creator_path_slices(patient_dir):
    """Return the full paths of every file in ``patient_dir``, in slice order.

    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent order, so the
    names are sorted explicitly. Names whose stem (the text before the first ``.``)
    parses as an integer -- the ``<slice_number>.<ext>`` convention that
    ``TomosynthesisDataset`` relies on -- are ordered numerically, so ``"2.png"``
    comes before ``"10.png"``. Any other names follow, sorted lexicographically.

    Args:
        patient_dir: Directory holding the slice images of one stack.

    Returns:
        List of ``os.path.join(patient_dir, name)`` strings (empty for an empty dir).
    """
    def _slice_order(name):
        # Numeric stems first (by value), then everything else by name.
        try:
            return (0, int(name.split('.')[0]), name)
        except ValueError:
            return (1, 0, name)

    return [os.path.join(patient_dir, name)
            for name in sorted(os.listdir(patient_dir), key=_slice_order)]

In [None]:
def get_augmentations(max_rotation=10):
    """Build the albumentations pipeline applied to the augmented (oversampled) copies.

    Every transform samples its random parameters each time the pipeline is called.
    This includes the rotation, which is drawn uniformly from
    ``[-max_rotation, max_rotation]`` degrees. Drawing the angle once when the
    pipeline is built would rotate every augmented image by the same fixed angle.

    Args:
        max_rotation: Maximum absolute rotation in degrees (default 10).

    Returns:
        An ``A.Compose`` pipeline, called as ``pipeline(image=img)['image']``.
    """
    return A.Compose([
        A.Affine(rotate=(-max_rotation, max_rotation), p=1),
        A.Affine(
            scale=(1 - 0.15, 1 + 0.15),
            translate_percent=(-0.1, 0.1),
            p=0.8, border_mode=cv2.BORDER_CONSTANT),
        A.RandomBrightnessContrast(
            brightness_limit=0.1,
            contrast_limit=0.1,
            p=0.5,
        ),
        A.GaussianBlur(blur_limit=(3, 5), p=0.1),
        A.GaussNoise(std_range=(0.05, 0.1), p=0.2),
    ])

In [None]:
class TomosynthesisDataset(Dataset):
    """Tomosynthesis stacks labelled with the position of their representative slice.

    Each CSV row describes one stack:

    * ``PatientID`` and ``Intermedia`` locate the slice folder
      (``images_folder/PatientID/Intermedia``).
    * ``Slice_representativo`` is the annotated slice number.
    * ``Clasificacion`` is the row's class.

    Rows with ``Clasificacion == 1`` can be oversampled with a second, augmented copy.

    ``__getitem__`` returns ``(slices, target)``:

    * ``slices`` is ``torch.stack`` of one ``ToTensor`` image per slice file, in the
      order returned by ``creator_path_slices``.
    * ``target`` is a ``torch.long`` scalar: the position in that stack of the slice
      whose number is closest to ``Slice_representativo``.
    """

    def __init__(self, csv_path, images_folder, augmentations=None, num_slices=27,
                 oversample_positives=None):
        """
        Args:
            csv_path: CSV with the columns described in the class docstring.
            images_folder: Root folder containing one sub-folder per patient.
            augmentations: Optional callable used as
                ``augmentations(image=img)['image']`` on the augmented copies.
            num_slices: Number of slice files every stack folder must contain.
            oversample_positives: Whether rows with ``Clasificacion == 1`` get an
                extra augmented copy. Defaults to ``augmentations is not None``:
                without augmentations the copy would be an exact duplicate, which
                silently double-counts positives (e.g. in a validation set).
        """
        self.metadata = pd.read_csv(csv_path)
        self.images_folder = images_folder
        self.augmentations = augmentations
        self.num_slices = num_slices
        if oversample_positives is None:
            oversample_positives = augmentations is not None
        self.data = []

        for row in self.metadata.to_dict('records'):
            patient_id = row['PatientID']
            classification = row['Clasificacion']
            entry = {
                'patient_id': patient_id,
                'target_slice': row['Slice_representativo'],
                'intermediate_folder': os.path.join(images_folder, patient_id, row['Intermedia']),
                'classification': classification,
            }
            # Every row contributes its original item, in CSV row order ...
            self.data.append((entry, False))  # False --> original image
            # ... and positives may get a second, augmented copy right after it.
            if oversample_positives and classification == 1:
                self.data.append((entry, True))  # True --> augmented image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry, apply_augmentation = self.data[idx]

        paths_slices = creator_path_slices(entry['intermediate_folder'])
        if len(paths_slices) != self.num_slices:
            raise ValueError(f"Expected {self.num_slices} slices, got {len(paths_slices)} for patient {entry['patient_id']}")

        # Slice files are named "<slice number>.<ext>". The target is the position
        # (within this stack) of the slice whose number is closest to the annotated
        # one; min() keeps the first position on ties.
        target_slice = entry['target_slice']
        slice_numbers = [int(os.path.basename(p).split('.')[0]) for p in paths_slices]
        target_slice_subsample_position = min(
            range(len(slice_numbers)),
            key=lambda i: abs(target_slice - slice_numbers[i]),
            default=0,
        )

        to_tensor = transforms.ToTensor()
        slices = []
        for path in paths_slices:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Slice not found: {path}")

            slice_img = cv2.imread(path, cv2.IMREAD_COLOR)
            if slice_img is None:
                # cv2.imread signals unreadable/corrupt files by returning None.
                raise ValueError(f"Could not decode slice image: {path}")
            slice_img = cv2.cvtColor(slice_img, cv2.COLOR_BGR2RGB)

            if apply_augmentation and self.augmentations is not None:
                # NOTE(review): the pipeline is called once per slice, so each slice
                # gets independently sampled random parameters (rotation, shift,
                # noise...). If the whole stack should be transformed identically,
                # the parameters must be replayed across slices -- TODO confirm intent.
                slice_img = self.augmentations(image=slice_img)['image']

            slices.append(to_tensor(slice_img))

        slices_tensor = torch.stack(slices)
        return slices_tensor, torch.tensor(target_slice_subsample_position, dtype=torch.long)


## Dataloader sanity checks

In [None]:
"""train_dataset = TomosynthesisDataset(
      csv_path='/content/drive/MyDrive/slices.csv',
      images_folder='/content/drive/MyDrive/PruebaColab',
      augmentations = get_augmentations()
    )"""

"train_dataset = TomosynthesisDataset(\n      csv_path='/content/drive/MyDrive/slices.csv',\n      images_folder='/content/drive/MyDrive/PruebaColab',\n      augmentations = get_augmentations()\n    )"

In [None]:

"""# Crear el DataLoader
data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Imprimir algunos labels para verificar
for i, (slices, labels) in enumerate(data_loader):
    print(f"Lote {i + 1} - Labels: {labels.size(0)}")

    if i == 5:  # Mostrar solo los primeros 5 lotes
        break"""

'# Crear el DataLoader\ndata_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n\n# Imprimir algunos labels para verificar\nfor i, (slices, labels) in enumerate(data_loader):\n    print(f"Lote {i + 1} - Labels: {labels.size(0)}")\n\n    if i == 5:  # Mostrar solo los primeros 5 lotes\n        break'

In [None]:
"""import matplotlib.pyplot as plt
import numpy as np
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Obtener un batch del dataloader
data_iter = iter(dataloader)
slices_tensor, labels = next(data_iter)  # slices_tensor tiene forma [batch, 27, C, H, W]

# Seleccionar un ejemplo del batch (por ejemplo, el primer paciente)
index = 0  # Puedes cambiar esto para ver otros casos
sample_slices = slices_tensor[index]  # [27, C, H, W]

# Visualizar algunas slices del volumen
fig, axes = plt.subplots(1, 5, figsize=(15, 5))  # Mostrar 5 slices
for i, ax in enumerate(axes):
    slice_img = sample_slices[i].permute(1, 2, 0).numpy()  # Convertir tensor a imagen
    slice_img = (slice_img * 255).astype(np.uint8)  # Reescalar si es necesario
    ax.imshow(slice_img)
    ax.set_title(f"Slice {i}")
    ax.axis("off")

plt.show()
"""

'import matplotlib.pyplot as plt\nimport numpy as np\ndataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n\n# Obtener un batch del dataloader\ndata_iter = iter(dataloader)\nslices_tensor, labels = next(data_iter)  # slices_tensor tiene forma [batch, 27, C, H, W]\n\n# Seleccionar un ejemplo del batch (por ejemplo, el primer paciente)\nindex = 0  # Puedes cambiar esto para ver otros casos\nsample_slices = slices_tensor[index]  # [27, C, H, W]\n\n# Visualizar algunas slices del volumen\nfig, axes = plt.subplots(1, 5, figsize=(15, 5))  # Mostrar 5 slices\nfor i, ax in enumerate(axes):\n    slice_img = sample_slices[i].permute(1, 2, 0).numpy()  # Convertir tensor a imagen\n    slice_img = (slice_img * 255).astype(np.uint8)  # Reescalar si es necesario\n    ax.imshow(slice_img)\n    ax.set_title(f"Slice {i}")\n    ax.axis("off")\n\nplt.show()\n'

In [None]:
"""for batch_slices, batch_targets in dataloader:
    # batch_slices: [batch_size, num_slices, channels, height, width]
    # batch_targets: [batch_size]
    print(batch_slices.shape)
    print(batch_targets)
"""

'for batch_slices, batch_targets in dataloader:\n    # batch_slices: [batch_size, num_slices, channels, height, width]\n    # batch_targets: [batch_size]\n    print(batch_slices.shape)\n    print(batch_targets)\n'

# ARCHITECTURE

## Attention Module

This module scores every slice of a tomosynthesis stack; the highest-scoring slice is taken as the most representative one.

In [None]:
class RepresentativeSliceDetector(nn.Module):
    """Scores every slice of a stack; the highest score marks the representative slice.

    Each slice's feature map is averaged over its spatial dimensions and layer-normalised.
    It then goes through a small MLP (Linear -> ReLU -> Dropout -> Linear) that emits
    one logit per slice.

    Args:
        feature_dim: Channel count of the incoming per-slice features.
        hidden_dim: Width of the MLP hidden layer.
        dropout: Dropout rate inside the MLP.
    """

    def __init__(self, feature_dim=512, hidden_dim=256, dropout=0.1):
        super().__init__()
        self.feature_dim = feature_dim

        self.attention = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )
        self.layer_norm = nn.LayerNorm(self.feature_dim)

    def forward(self, feature_map):
        """Return one unnormalised logit per slice.

        Args:
            feature_map: Either per-slice feature maps of shape
                (batch_size, num_slices, feature_dim, H, W), which are averaged over
                H and W, or already pooled features of shape
                (batch_size, num_slices, feature_dim).

        Returns:
            Tensor of shape (batch_size, num_slices) with slice logits, suitable as
            input to ``nn.CrossEntropyLoss`` over the slice axis.

        Raises:
            ValueError: If ``feature_map`` is neither 5-D nor 3-D.
        """
        if feature_map.dim() == 5:
            features = feature_map.mean(dim=(3, 4))  # (batch_size, num_slices, feature_dim)
        elif feature_map.dim() == 3:
            features = feature_map
        else:
            raise ValueError(
                f"Expected a 5-D (B, N, C, H, W) or 3-D (B, N, C) tensor, got shape {tuple(feature_map.shape)}"
            )

        features = self.layer_norm(features)
        logits = self.attention(features)  # (batch_size, num_slices, 1)

        return logits.squeeze(-1)

## Feature Extractor

Returns the C3, C4 and C5 ResNet-50 feature maps of each slice.

In [None]:
class ResNet50FeatureExtractor(nn.Module):
    """Runs a torchvision ResNet-50 over every slice of a stack.

    Input ``x`` has shape (batch_size, num_slices, C, H, W). Slices are folded into
    the batch axis and passed through the ResNet stem and stages. The outputs of
    ``layer2``/``layer3``/``layer4`` are returned as ``{'c3', 'c4', 'c5'}``, each
    reshaped to (batch_size, num_slices, channels, h, w).

    Args:
        weights: torchvision weights enum passed to ``models.resnet50``.
        dropout: Rate of ``self.dropout``. NOTE(review): this module is created but
            never applied in ``forward``.
        dtype: Optional dtype to cast the ResNet parameters/buffers to (e.g.
            ``torch.bfloat16`` reproduces the former all-bf16 backbone). By default
            the weights stay in float32. The training loops already run under
            autocast, and bf16 master weights can round small optimizer updates
            (lr ~1e-5) away entirely.
    """

    def __init__(self, weights=ResNet50_Weights.DEFAULT, dropout=0.1, dtype=None):
        super(ResNet50FeatureExtractor, self).__init__()
        self.resnet50 = models.resnet50(weights=weights)
        if dtype is not None:
            self.resnet50.to(dtype)
        self.stem = nn.Sequential(
            self.resnet50.conv1,
            self.resnet50.bn1,
            self.resnet50.relu,
            self.resnet50.maxpool  # 1/4 resolution
        )

        self.layer1 = self.resnet50.layer1  # C2 (this layer is not going to be used in the FPN)
        self.layer2 = self.resnet50.layer2  # C3
        self.layer3 = self.resnet50.layer3  # C4
        self.layer4 = self.resnet50.layer4  # C5

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, num_slices, C, H, W = x.shape
        # Fold the slice axis into the batch axis; reshape (unlike view) also accepts
        # non-contiguous inputs.
        x = self.stem(x.reshape(batch_size * num_slices, C, H, W))
        c2 = self.layer1(x)   # (B*N, 256, H/4, W/4)
        c3 = self.layer2(c2)  # (B*N, 512, H/8, W/8)
        c4 = self.layer3(c3)  # (B*N, 1024, H/16, W/16)
        c5 = self.layer4(c4)  # (B*N, 2048, H/32, W/32)

        def unfold(f):
            # (B*N, C', h, w) -> (B, N, C', h, w)
            return f.reshape(batch_size, num_slices, *f.shape[1:])

        return {
            'c3': unfold(c3),
            'c4': unfold(c4),
            'c5': unfold(c5)
        }


## MRI with FPN

The attention module is fed with the C5-level features after they have been passed through a Feature Pyramid Network (FPN).

In [None]:
class ModifiedResNet50Backbone(nn.Module):
    """Slice-selection backbone: ResNet-50 features -> FPN -> slice attention logits.

    ``forward(x)`` takes a stack tensor of shape (batch_size, num_slices, C, H, W) and
    returns slice logits of shape (batch_size, num_slices). The logits are computed
    from the coarsest FPN level (built from C5).

    Args:
        out_channels: FPN output channels, which is also the attention feature size.
    """

    def __init__(self, out_channels=512):
        super(ModifiedResNet50Backbone, self).__init__()

        try:
            self.device = xm.xla_device()
        except Exception:
            # e.g. no XLA device available: fall back to CUDA, then CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.out_channels = out_channels
        self.feature_extractor = ResNet50FeatureExtractor()
        self.attention_module = RepresentativeSliceDetector(feature_dim=self.out_channels)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=[512, 1024, 2048],
            out_channels=self.out_channels
        ).to(self.device)

    def forward(self, x):
        x = x.to(self.device)

        features = self.feature_extractor(x)
        B, N = features["c3"].shape[:2]

        # The FPN works on 4-D maps: fold slices into the batch axis for every level.
        fpn_in = OrderedDict(
            (str(level), features[name].reshape(B * N, *features[name].shape[2:]))
            for level, name in enumerate(("c3", "c4", "c5"))
        )
        fpn_out = self.fpn(fpn_in)

        # Only the coarsest level ("2", from C5) feeds the attention module, so only
        # that one is restored to (B, N, out_channels, h, w).
        top_level = fpn_out["2"]
        top_level = top_level.reshape(B, N, self.out_channels, *top_level.shape[2:])

        return self.attention_module(top_level)


## MRI without FPN

The attention module is fed with the C5 features taken directly from the feature extractor (no FPN).

In [None]:
class ModifiedResNet50Backbone(nn.Module):
    """Slice-selection backbone without FPN: ResNet-50 C5 features -> slice attention.

    ``forward(x)`` takes a stack tensor of shape (batch_size, num_slices, C, H, W) and
    returns slice logits of shape (batch_size, num_slices).

    Args:
        out_channels: Feature size given to the attention module. It must equal the
            channel count of C5 (2048 for ResNet-50), which is fed in directly.
    """

    def __init__(self, out_channels=2048):
        super(ModifiedResNet50Backbone, self).__init__()

        try:
            self.device = xm.xla_device()
        except Exception:
            # e.g. no XLA device available: fall back to CUDA, then CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.out_channels = out_channels
        self.feature_extractor = ResNet50FeatureExtractor()
        self.attention_module = RepresentativeSliceDetector(feature_dim=self.out_channels)

    def forward(self, x):
        features = self.feature_extractor(x.to(self.device))
        # Only the deepest stage (C5) is used; c3/c4 are computed but ignored.
        return self.attention_module(features["c5"])


## Data flow of the backbone

In [None]:
# Instantiate the backbone (the definition that ran last) and print its layer summary.
# input_size is (batch=8, slices=27, channels=3, 256, 256): batch 8 and 27 slices match
# the DataLoaders and the dataset's slice-count check; 256x256 is assumed -- confirm
# against the real slice resolution.
modelo = ModifiedResNet50Backbone()
summary(modelo, input_size=(8,27, 3, 256, 256))

Layer (type:depth-idx)                        Output Shape              Param #
ModifiedResNet50Backbone                      [8, 27]                   --
├─ResNet50FeatureExtractor: 1-1               [8, 27, 2048, 8, 8]       2,049,000
│    └─Sequential: 2-1                        [216, 64, 64, 64]         --
│    │    └─Conv2d: 3-1                       [216, 64, 128, 128]       9,408
│    │    └─BatchNorm2d: 3-2                  [216, 64, 128, 128]       128
│    │    └─ReLU: 3-3                         [216, 64, 128, 128]       --
│    │    └─MaxPool2d: 3-4                    [216, 64, 64, 64]         --
│    └─Sequential: 2-2                        [216, 256, 64, 64]        --
│    │    └─Bottleneck: 3-5                   [216, 256, 64, 64]        75,008
│    │    └─Bottleneck: 3-6                   [216, 256, 64, 64]        70,400
│    │    └─Bottleneck: 3-7                   [216, 256, 64, 64]        70,400
│    └─Sequential: 2-3                        [216, 512, 32, 32]        

# LOAD DATALOADERS

## Load data from a single dataset (stratified train/validation split)


In [None]:
# Both splits must read the same CSV and image folder, because the split indices
# below are CSV row indices. TODO: set these two placeholders.
CSV_PATH = 'csv_path'
IMAGES_FOLDER = 'image_folder_path'

df = pd.read_csv(CSV_PATH)
labels = df['Clasificacion'].tolist()
indices = list(range(len(df)))

# Split CSV *rows* (stratified by class), so a given stack lands in exactly one split.
train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)


def _row_positions(dataset):
    """Map every CSV row to its list of positions in ``dataset.data``.

    TomosynthesisDataset appends, in CSV row order, an ``(entry, False)`` item per
    row, optionally followed by an ``(entry, True)`` augmented copy. Dataset positions
    are therefore not CSV row indices and must be translated before building a Subset.
    """
    positions = []
    for pos, (_, is_augmented_copy) in enumerate(dataset.data):
        if is_augmented_copy:
            positions[-1].append(pos)
        else:
            positions.append([pos])
    assert len(positions) == len(df), "dataset items do not line up with the CSV rows"
    return positions


base_dataset = TomosynthesisDataset(
    csv_path=CSV_PATH,
    images_folder=IMAGES_FOLDER,
    augmentations=None
)

train_full = TomosynthesisDataset(
    csv_path=CSV_PATH,
    images_folder=IMAGES_FOLDER,
    augmentations=get_augmentations()
)

train_positions = _row_positions(train_full)
val_positions = _row_positions(base_dataset)

# Train keeps every item of its rows (original + augmented oversample); validation
# keeps only the original, un-augmented item of each of its rows.
train_dataset = Subset(train_full, [p for row in train_idx for p in train_positions[row]])
val_dataset = Subset(base_dataset, [val_positions[row][0] for row in val_idx])


train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    pin_memory=True,
    num_workers=2
)


print(f'Tamaño train batches: {len(train_dataloader)}.')
print(f'Tamaño val batches: {len(val_dataloader)}.')

Tamaño train batches: 289.
Tamaño val batches: 73.


## Load data from an already split dataset

In [None]:
def get_augmentations():
    """Rotation-only augmentation for the pre-split experiment.

    This redefines (and shadows) the fuller ``get_augmentations`` pipeline defined
    earlier in the notebook. The angle is sampled in [-10, 10] degrees on every call
    of the pipeline. Drawing it once when the pipeline is built would rotate every
    augmented image by the same angle.
    """
    return A.Compose([A.Affine(rotate=(-10, 10), p=1)])


train_dataset = TomosynthesisDataset(
    csv_path="csv_train_path",
    images_folder='image_folder_train_path',
    augmentations=get_augmentations()
)

val_dataset = TomosynthesisDataset(
    csv_path='csv_test_path',
    images_folder='image_folder_val_path',
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
)


val_dataloader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    pin_memory=True,
    num_workers=2
)


print(f'Tamaño train batches: {len(train_dataloader)}.')
print(f'Tamaño val batches: {len(val_dataloader)}.')

Tamaño train batches: 101.
Tamaño val batches: 20.


# TRAINING

## GPU training

In [None]:
def train_model(model, train_dataloader, val_dataloader, num_epochs=20, lr=1e-5, device=None):
    """Train ``model`` on CUDA/CPU with exact-match slice accuracy and early stopping.

    ``model(images)`` must return slice logits of shape (batch, num_slices); targets
    are slice positions and the loss is ``nn.CrossEntropyLoss``. Mixed precision
    (autocast + GradScaler) is enabled only when ``device`` is CUDA.

    Each epoch records, for both splits:

    * loss and accuracy (exact match);
    * weighted precision and recall;
    * weighted Jaccard (IoU).

    Validation softmax probabilities and labels are recorded as well. The LR is
    reduced on a validation-loss plateau. The best model by validation accuracy is
    saved to ``best_model.pth``. Training stops after ``patience`` (10) epochs
    without improvement. The history is pickled to ``training_history.pkl`` after
    every epoch.

    Args:
        model: Module mapping (B, N, C, H, W) stacks to (B, N) logits.
        train_dataloader / val_dataloader: Yield ``(images, labels)`` batches.
        num_epochs: Maximum number of epochs.
        lr: Adam learning rate.
        device: Target device (``torch.device`` or string); defaults to CUDA if
            available, else CPU.

    Returns:
        dict of per-epoch metric lists.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)
    use_amp = device.type == "cuda"  # mixed precision only on CUDA; no-op elsewhere

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=7)
    if hasattr(torch.amp, "GradScaler"):
        scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    else:  # older PyTorch releases
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_val_acc = float("-inf")  # so the first epoch is always checkpointed
    training_history = {
        "train_loss": [], "train_acc": [], "train_precision": [], "train_recall": [], "train_iou": [],
        "val_loss": [], "val_acc": [], "val_precision": [], "val_recall": [], "val_iou": [],
        "val_probs": [], "val_labels": []
    }

    patience = 10
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # ---------------------------- Training ----------------------------
        model.train()
        train_loss, train_correct, total_samples = 0.0, 0, 0
        all_preds, all_labels = [], []

        train_loader_tqdm = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", leave=False)

        for images, labels in train_loader_tqdm:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            with torch.autocast(device_type=device.type, enabled=use_amp):
                # The backbone returns only the (batch, num_slices) slice logits.
                logits = model(images)
                loss = loss_fn(logits, labels.long())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item() * images.size(0)
            preds = torch.argmax(logits, dim=1)
            train_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            train_loader_tqdm.set_postfix(loss=loss.item())

        train_loss /= total_samples
        train_acc = train_correct / total_samples
        train_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        train_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        train_iou = jaccard_score(all_labels, all_preds, average='weighted', zero_division=0)

        training_history["train_loss"].append(train_loss)
        training_history["train_acc"].append(train_acc)
        training_history["train_precision"].append(train_precision)
        training_history["train_recall"].append(train_recall)
        training_history["train_iou"].append(train_iou)

        # ---------------------------- Validation ----------------------------
        model.eval()
        val_loss, val_correct, total_val_samples = 0.0, 0, 0
        all_val_preds, all_val_labels, all_val_probs = [], [], []

        with torch.no_grad():
            val_loader_tqdm = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", leave=False)

            for images, labels in val_loader_tqdm:
                images, labels = images.to(device), labels.to(device)
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    logits = model(images)
                    loss = loss_fn(logits, labels.long())

                val_loss += loss.item() * images.size(0)
                preds = torch.argmax(logits, dim=1)
                val_correct += (preds == labels).sum().item()
                total_val_samples += labels.size(0)

                all_val_preds.extend(preds.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())
                # Softmax in float32: logits may be float16 under autocast.
                all_val_probs.extend(torch.softmax(logits.float(), dim=1).cpu().numpy())

                val_loader_tqdm.set_postfix(loss=loss.item())

        val_loss /= total_val_samples
        val_acc = val_correct / total_val_samples
        val_precision = precision_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)
        val_recall = recall_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)
        val_iou = jaccard_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)

        training_history["val_loss"].append(val_loss)
        training_history["val_acc"].append(val_acc)
        training_history["val_precision"].append(val_precision)
        training_history["val_recall"].append(val_recall)
        training_history["val_iou"].append(val_iou)
        training_history["val_probs"].append(np.array(all_val_probs))
        training_history["val_labels"].append(np.array(all_val_labels))

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train Prec: {train_precision:.4f}, Train Rec: {train_recall:.4f}, Train IoU: {train_iou:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val IoU: {val_iou:.4f}")

        scheduler.step(val_loss)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
            torch.save(model.state_dict(), "best_model.pth")
            print(f"Best model in epoch {epoch + 1}")
        else:
            epochs_no_improve += 1

        # Persist the history every epoch, including the one that triggers early stopping.
        with open("training_history.pkl", "wb") as f:
            pickle.dump(training_history, f)

        if epochs_no_improve >= patience:
            print("Early stopping activado.")
            break

    return training_history

## TPU training

In [None]:
def train_model(model, train_dataloader, val_dataloader, num_epochs=90, slice_tolerance=2):
    """Train ``model`` on an XLA (TPU) device with bf16 autocast and early stopping.

    ``model(images)`` must return slice logits of shape (batch, num_slices), and the
    model must expose a ``feature_extractor`` submodule. Targets are slice positions;
    the loss is ``CrossEntropyLoss``.

    Optimisation:

    * torch_xla sync-free AdamW (weight decay 1e-5).
    * lr 1e-5 for ``feature_extractor``.
    * lr 1e-4 for every other parameter (the attention head, plus the FPN when the
      model has one).
    * ReduceLROnPlateau on the validation loss.

    Reported "accuracy" counts a prediction as correct when the predicted slice
    position is within ``slice_tolerance`` positions of the target. Precision, recall
    and IoU are weighted sklearn scores over the exact positions.

    The model with the lowest validation loss is saved to ``best_model.pth`` by the
    master ordinal. Training stops once more than ``patience`` (17) consecutive
    epochs pass without a new best validation loss.

    Returns:
        dict of per-epoch metrics. ``val_probs`` holds per-sample softmax
        probabilities over slices.
    """
    device = xm.xla_device()
    model = model.to(device)

    backbone_params = list(model.feature_extractor.parameters())
    backbone_ids = {id(p) for p in backbone_params}
    # Everything outside the backbone trains at the higher LR, so extra heads
    # (e.g. an FPN) are optimised too instead of silently staying frozen.
    head_params = [p for p in model.parameters() if id(p) not in backbone_ids]

    optimizer = adamw.AdamW([
        {'params': backbone_params, 'lr': 1e-5},
        {'params': head_params, 'lr': 1e-4}
    ], weight_decay=1e-5)

    loss_fn = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=8)

    best_val_loss = float("inf")
    training_history = {
        "train_loss": [], "train_acc": [], "train_precision": [], "train_recall": [], "train_iou": [],
        "val_loss": [], "val_acc": [], "val_precision": [], "val_recall": [], "val_iou": [],
        "val_probs": [], "val_labels": [], "train_preds": [], "train_labels": [], "val_preds": []
    }

    patience = 17
    epochs_no_improve = 0

    for epoch in range(num_epochs):

        # ----------------------------------- TRAIN -----------------------------------
        model.train()
        train_loss, train_correct, total_samples = 0.0, 0, 0
        all_train_preds, all_train_labels = [], []

        train_loader_tqdm = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", leave=True)

        for images, labels in train_loader_tqdm:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            with torch_xla.amp.autocast(device=device, dtype=torch.bfloat16):
                logits = model(images)
                loss = loss_fn(logits, labels.long())

            loss.backward()
            xm.optimizer_step(optimizer)

            train_loss += loss.item() * images.size(0)
            preds = torch.argmax(logits, dim=1)
            train_correct += ((preds - labels).abs() <= slice_tolerance).sum().item()
            total_samples += labels.size(0)

            all_train_preds.extend(preds.cpu().numpy())
            all_train_labels.extend(labels.cpu().numpy())

            train_loader_tqdm.set_postfix(loss=loss.item())

            del logits, loss, images, labels, preds
            xm.mark_step()

        train_loss /= total_samples
        train_acc = train_correct / total_samples
        train_precision = precision_score(all_train_labels, all_train_preds, average='weighted', zero_division=0)
        train_recall = recall_score(all_train_labels, all_train_preds, average='weighted', zero_division=0)
        train_iou = jaccard_score(all_train_labels, all_train_preds, average='weighted', zero_division=0)

        training_history["train_loss"].append(train_loss)
        training_history["train_acc"].append(train_acc)
        training_history["train_precision"].append(train_precision)
        training_history["train_recall"].append(train_recall)
        training_history["train_iou"].append(train_iou)
        training_history["train_preds"].append(np.array(all_train_preds))
        training_history["train_labels"].append(np.array(all_train_labels))

        # --------------------------------- VALIDATION ---------------------------------
        model.eval()
        val_loss, val_correct, total_val_samples = 0.0, 0, 0
        all_val_preds, all_val_labels, all_val_probs = [], [], []

        with torch.no_grad():
            val_loader_tqdm = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", leave=True)

            for images, labels in val_loader_tqdm:
                images, labels = images.to(device), labels.to(device)

                with torch_xla.amp.autocast(device=device, dtype=torch.bfloat16):
                    logits = model(images)
                    loss = loss_fn(logits, labels.long())

                val_loss += loss.item() * images.size(0)
                preds = torch.argmax(logits, dim=1)

                val_correct += ((preds - labels).abs() <= slice_tolerance).sum().item()
                total_val_samples += labels.size(0)

                all_val_preds.extend(preds.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())
                # Store probabilities (not raw bf16 logits) under "val_probs".
                all_val_probs.extend(torch.softmax(logits.float(), dim=1).cpu().numpy())

                val_loader_tqdm.set_postfix(loss=loss.item())

                del logits, loss, images, labels, preds
                xm.mark_step()

        val_loss /= total_val_samples
        val_acc = val_correct / total_val_samples
        val_precision = precision_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)
        val_recall = recall_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)
        val_iou = jaccard_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)

        training_history["val_loss"].append(val_loss)
        training_history["val_acc"].append(val_acc)
        training_history["val_precision"].append(val_precision)
        training_history["val_recall"].append(val_recall)
        training_history["val_iou"].append(val_iou)
        training_history["val_probs"].append(np.array(all_val_probs))
        training_history["val_preds"].append(np.array(all_val_preds))
        training_history["val_labels"].append(np.array(all_val_labels))

        scheduler.step(val_loss)
        xm.rendezvous("lr_scheduler_sync")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()
            epochs_no_improve = 0

            if xm.is_master_ordinal():
                xm.save(best_model_state, "best_model.pth")
        else:
            epochs_no_improve += 1

        if epochs_no_improve > patience:
            print("Early stopping activated.")
            break

        xm.rendezvous("cleanup")
    return training_history


## Train

In [None]:
# Warm-start from a previously saved backbone state_dict. NOTE(review): despite its
# name, this file is loaded as model weights, not as a training-history pickle.
CHECKPOINT_PATH = "/content/drive/MyDrive/training_historyBackbone10+22.pth"

model = ModifiedResNet50Backbone()
# map_location="cpu" lets checkpoints saved on any device (CUDA, TPU via xm.save, CPU)
# be loaded; train_model moves the model to its own device afterwards.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))


training_history = train_model(model, train_dataloader, val_dataloader)
with open("training_history.pkl", "wb") as f:
    pickle.dump(training_history, f)