In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from transformers import ViTForImageClassification, ViTConfig
import pandas as pd
import numpy as np
from PIL import Image
import os
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MedicalImageDataset(Dataset):
    """Labelled image dataset backed by a CSV index.

    The CSV must contain an ``id`` column (image file stem) and a ``label``
    column; sample ``i`` is the image ``<img_dir>/<id>.tif`` paired with its label.

    Args:
        csv_file: path to the CSV index.
        img_dir: directory holding the ``.tif`` files.
        transform: optional callable applied to the PIL image
            (e.g. a torchvision ``Compose``).
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        # Extract the columns once: a per-item ``DataFrame.iloc`` row lookup is
        # comparatively slow and __getitem__ runs once per sample every epoch.
        self.image_ids = self.data['id'].values
        self.labels = self.data['label'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Append the .tif extension to the image id (f-string also tolerates non-str ids).
        img_path = os.path.join(self.img_dir, f'{self.image_ids[idx]}.tif')
        # ``with`` releases the file handle (Pillow may otherwise keep TIFF files
        # open); ``convert('RGB')`` guarantees 3 channels, which the 3-value
        # Normalize used in this notebook requires, even for grayscale/RGBA files.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]

In [3]:
class TestDataset(Dataset):
    """Unlabelled image dataset for inference, indexed by a CSV ``id`` column.

    Sample ``i`` is ``(image, id)`` where the image is loaded from
    ``<img_dir>/<id>.tif``; the id is returned unchanged so predictions can be
    mapped back onto the submission file.

    Args:
        csv_file: path to the CSV index (must contain an ``id`` column).
        img_dir: directory holding the ``.tif`` files.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.image_ids = self.data['id'].values

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        # Append the .tif extension to the image id (f-string also tolerates non-str ids).
        img_path = os.path.join(self.img_dir, f'{image_id}.tif')
        # ``with`` releases the file handle; ``convert('RGB')`` guarantees the
        # 3 channels expected by the 3-value Normalize used in this notebook.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, image_id

In [4]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module whose forward output exposes ``.logits`` (as the
            Hugging Face classifier used in this notebook does).
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``, as with ``nn.CrossEntropyLoss()`` here).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc='Training', mininterval=2.0):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion returns the batch mean; weight it by the batch size so a
        # short final batch does not count as much as a full one.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError('train_loader yielded no samples')

    epoch_loss = loss_sum / total
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc


In [5]:
def evaluate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: module whose forward output exposes ``.logits``.
        val_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function; assumed to average over the batch
            (``reduction='mean'``, as with ``nn.CrossEntropyLoss()`` here).
        device: device the batches are moved to.

    Returns:
        ``(val_loss, val_acc, all_preds, all_labels, all_probs)``:
        per-sample mean loss, accuracy in percent, predicted class per sample,
        true label per sample, and the softmax probability of class index 1 per
        sample (the positive-class score used for the ROC curve).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='Evaluating', mininterval=2.0):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so the short last batch
            # is not over-represented in the epoch average.
            loss_sum += loss.item() * batch_size
            probs = torch.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())

    if total == 0:
        raise ValueError('val_loader yielded no samples')

    val_loss = loss_sum / total
    val_acc = 100. * correct / total
    return val_loss, val_acc, all_preds, all_labels, all_probs


In [6]:
def plot_metrics(train_losses, train_accs, val_losses, val_accs):
    """Plot per-epoch loss and accuracy curves side by side.

    Args:
        train_losses, val_losses: one loss value per epoch.
        train_accs, val_accs: one accuracy per epoch, in percent (the unit
            returned by train_epoch / evaluate).
    """
    def _epochs(values):
        # The training loop prints epochs as 1..N; number the x axis the same way.
        return range(1, len(values) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(_epochs(train_losses), train_losses, label='Train Loss')
    loss_ax.plot(_epochs(val_losses), val_losses, label='Val Loss')
    loss_ax.set(title='Loss', xlabel='Epoch', ylabel='Loss')
    loss_ax.legend()

    acc_ax.plot(_epochs(train_accs), train_accs, label='Train Acc')
    acc_ax.plot(_epochs(val_accs), val_accs, label='Val Acc')
    acc_ax.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()

In [7]:
def plot_confusion_matrix(cm, classes=None):
    """Draw a confusion matrix as a heat map with the value printed in each cell.

    Args:
        cm: square array-like confusion matrix with rows = true labels and
            columns = predicted labels (the layout of
            ``sklearn.metrics.confusion_matrix``, which the training cell passes).
        classes: optional tick labels, one per row/column. Defaults to
            ``'Class 0', 'Class 1', ...`` sized to ``cm``.
    """
    cm = np.asarray(cm)
    if classes is None:
        classes = [f'Class {i}' for i in range(cm.shape[0])]

    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title('Confusion Matrix')
    fig.colorbar(im, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # Print each cell's value; use white text on dark cells so it stays legible.
    value_fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    threshold = cm.max() / 2.0 if cm.size else 0.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], value_fmt),
                    ha='center', va='center',
                    color='white' if cm[i, j] > threshold else 'black')

    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    plt.show()

In [8]:
def plot_roc_curve(fpr, tpr, roc_auc):
    """Plot a ROC curve together with the chance diagonal.

    Args:
        fpr: false positive rates (x coordinates), e.g. from ``roc_curve``.
        tpr: true positive rates (y coordinates), same length as ``fpr``.
        roc_auc: area under the curve, shown in the legend to two decimals.
    """
    _, ax = plt.subplots(figsize=(8, 6))

    curve_label = f'ROC curve (AUC = {roc_auc:.2f})'
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    # Dashed diagonal: the curve of a classifier that guesses at random.
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    plt.show()

In [9]:

def predict(csv_file='sample_submission.csv', img_dir='test',
            weights_path='best_model_vit3.pth',
            output_path='sample_submission_vit_small_03.csv',
            batch_size=32):
    """Predict a class for every test image and write a submission CSV.

    Args:
        csv_file: CSV listing the test ids (its ``id`` column drives both the
            dataset and the output file).
        img_dir: directory holding the test ``.tif`` images.
        weights_path: state-dict checkpoint written by the training cell.
        output_path: where the submission CSV is written.
        batch_size: inference batch size.

    Raises:
        RuntimeError: if some id in ``csv_file`` received no prediction
            (instead of silently writing NaN labels).
    """
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data transforms - must match training transforms
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_dataset = TestDataset(csv_file=csv_file, img_dir=img_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Architecture must match training for the state dict to load. Hugging Face
    # configs take the class count as ``num_labels``; an unknown ``num_classes``
    # kwarg is merely stored as an unused attribute. Dropout settings are
    # omitted because dropout is inactive in eval mode.
    config = ViTConfig(
        image_size=96,
        patch_size=8,
        num_channels=3,
        num_labels=2,
        num_hidden_layers=6,
        hidden_size=384,
        num_attention_heads=6,
        intermediate_size=1536
    )
    model = ViTForImageClassification(config).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # image id -> predicted class index
    predictions = {}

    print("Making predictions...")
    with torch.no_grad():
        for images, image_ids in tqdm(test_loader, mininterval=2.0):
            outputs = model(images.to(device)).logits
            # One device->host transfer per batch instead of one .item() per image.
            batch_preds = outputs.argmax(dim=1).cpu().tolist()
            predictions.update(zip(image_ids, batch_preds))

    submission_df = pd.read_csv(csv_file)
    submission_df['label'] = submission_df['id'].map(predictions)

    n_missing = int(submission_df['label'].isna().sum())
    if n_missing:
        raise RuntimeError(f'{n_missing} ids in {csv_file} received no prediction')
    submission_df['label'] = submission_df['label'].astype(int)

    submission_df.to_csv(output_path, index=False)
    print(f"Predictions saved to {output_path}")

In [10]:
# Best-effort: move into the project data folder when the notebook is started
# from its parent directory. Only a missing folder is tolerated (e.g. on a
# re-run, when the kernel is already inside it); any other error is surfaced
# instead of being silently swallowed by a bare ``except``.
try:
    os.chdir('prueba_us')
except FileNotFoundError:
    print(f"'prueba_us' not found; staying in {os.getcwd()}")

In [11]:
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 2e-4
VAL_SPLIT = 0.1
SEED = 42

# Seed torch's global RNG so weight initialisation, dropout masks and
# DataLoader shuffling repeat across kernel restarts (some GPU kernels may
# still be nondeterministic).
torch.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data transforms (predict() must build the same pipeline for the test images)
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
dataset = MedicalImageDataset(
    csv_file='train.csv',
    img_dir='train',
    transform=transform
)

In [12]:
# Split dataset. A dedicated, seeded generator keeps the train/val partition
# identical across kernel restarts; otherwise a re-run would validate on images
# a previously saved checkpoint was trained on.
SPLIT_SEED = 42
val_size = int(VAL_SPLIT * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [13]:
# Initialize model
config = ViTConfig(
    image_size=96,
    patch_size=8,
    num_channels=3,
    num_classes=2,
    num_hidden_layers=6,
    hidden_size=384,
    num_attention_heads=6,
    intermediate_size=1536,
    hidden_dropout_prob=0.3,  # Agregar dropout en las capas ocultas
    attention_probs_dropout_prob=0.3  # Agregar dropout en la atención
)
model = ViTForImageClassification(config).to(device)

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Checkpoint holding the weights with the best validation accuracy so far;
# predict() loads a file with this same name.
BEST_MODEL_PATH = 'best_model_vit3.pth'

# Training tracking
train_losses = []
train_accs = []
val_losses = []
val_accs = []
best_val_acc = 0
best_epoch = 0  # 1-based epoch of the saved checkpoint; 0 = nothing saved yet

# Training loop
for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validate (per-sample outputs are not needed here; the final report below
    # re-evaluates the best checkpoint)
    val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f'Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}%')

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save(model.state_dict(), BEST_MODEL_PATH)

# Plot training curves
plot_metrics(train_losses, train_accs, val_losses, val_accs)

# After the loop `model` holds the LAST epoch's weights, but the checkpoint on
# disk (the one predict() uses) is from the best-validation epoch. Reload it so
# the metrics below describe the model that is actually used.
# NOTE: the checkpoint was selected on this same validation split, so these
# numbers are optimistically biased.
if best_epoch:
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    print(f'\nReloaded best checkpoint from epoch {best_epoch} (Val Acc: {best_val_acc:.2f}%)')
_, _, val_preds, val_labels, val_probs = evaluate(model, val_loader, criterion, device)

# Calculate final metrics
precision = precision_score(val_labels, val_preds)
recall = recall_score(val_labels, val_preds)
f1 = f1_score(val_labels, val_preds)
cm = confusion_matrix(val_labels, val_preds)

# ROC curve (scores are the class-1 probabilities returned by evaluate)
fpr, tpr, _ = roc_curve(val_labels, val_probs)
roc_auc = auc(fpr, tpr)

# Print final metrics
print('\nFinal Metrics:')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'ROC AUC: {roc_auc:.4f}')

# Plot confusion matrix and ROC curve
plot_confusion_matrix(cm)
plot_roc_curve(fpr, tpr, roc_auc)


Epoch 1/50


Training: 100%|██████████| 5570/5570 [30:05<00:00,  3.08it/s]
Evaluating: 100%|██████████| 619/619 [03:52<00:00,  2.66it/s]


Train Loss: 0.3866 Train Acc: 82.69%
Val Loss: 0.5805 Val Acc: 75.31%

Epoch 2/50


Training: 100%|██████████| 5570/5570 [30:30<00:00,  3.04it/s]
Evaluating: 100%|██████████| 619/619 [03:11<00:00,  3.23it/s]


Train Loss: 0.3439 Train Acc: 85.05%
Val Loss: 0.3763 Val Acc: 83.22%

Epoch 3/50


Training: 100%|██████████| 5570/5570 [26:56<00:00,  3.45it/s]
Evaluating: 100%|██████████| 619/619 [01:14<00:00,  8.28it/s]


Train Loss: 0.3224 Train Acc: 86.05%
Val Loss: 0.3297 Val Acc: 85.62%

Epoch 4/50


Training: 100%|██████████| 5570/5570 [12:32<00:00,  7.40it/s]
Evaluating: 100%|██████████| 619/619 [01:16<00:00,  8.12it/s]


Train Loss: 0.3060 Train Acc: 86.91%
Val Loss: 0.3690 Val Acc: 83.53%

Epoch 5/50


Training: 100%|██████████| 5570/5570 [12:34<00:00,  7.38it/s]
Evaluating: 100%|██████████| 619/619 [01:17<00:00,  8.01it/s]


Train Loss: 0.2908 Train Acc: 87.63%
Val Loss: 0.3481 Val Acc: 86.08%

Epoch 6/50


Training: 100%|██████████| 5570/5570 [12:46<00:00,  7.27it/s]
Evaluating: 100%|██████████| 619/619 [01:16<00:00,  8.06it/s]


Train Loss: 0.2767 Train Acc: 88.41%
Val Loss: 0.3158 Val Acc: 86.90%

Epoch 7/50


Training: 100%|██████████| 5570/5570 [12:39<00:00,  7.33it/s]
Evaluating: 100%|██████████| 619/619 [01:18<00:00,  7.90it/s]


Train Loss: 0.2639 Train Acc: 88.98%
Val Loss: 0.2531 Val Acc: 89.59%

Epoch 8/50


Training: 100%|██████████| 5570/5570 [12:23<00:00,  7.50it/s]
Evaluating: 100%|██████████| 619/619 [01:03<00:00,  9.73it/s]


Train Loss: 0.2530 Train Acc: 89.52%
Val Loss: 0.2606 Val Acc: 89.31%

Epoch 9/50


Training: 100%|██████████| 5570/5570 [11:16<00:00,  8.24it/s]
Evaluating: 100%|██████████| 619/619 [01:04<00:00,  9.57it/s]


Train Loss: 0.2426 Train Acc: 90.04%
Val Loss: 0.2535 Val Acc: 89.80%

Epoch 10/50


Training: 100%|██████████| 5570/5570 [11:24<00:00,  8.13it/s]
Evaluating: 100%|██████████| 619/619 [01:05<00:00,  9.51it/s]


Train Loss: 0.2317 Train Acc: 90.53%
Val Loss: 0.2404 Val Acc: 90.38%

Epoch 11/50


Training: 100%|██████████| 5570/5570 [11:19<00:00,  8.20it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.30it/s]


Train Loss: 0.2202 Train Acc: 91.07%
Val Loss: 0.3768 Val Acc: 84.28%

Epoch 12/50


Training: 100%|██████████| 5570/5570 [11:27<00:00,  8.11it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.32it/s]


Train Loss: 0.2113 Train Acc: 91.46%
Val Loss: 0.2428 Val Acc: 90.16%

Epoch 13/50


Training: 100%|██████████| 5570/5570 [11:22<00:00,  8.16it/s]
Evaluating: 100%|██████████| 619/619 [01:04<00:00,  9.61it/s]


Train Loss: 0.2016 Train Acc: 91.96%
Val Loss: 0.2119 Val Acc: 91.42%

Epoch 14/50


Training: 100%|██████████| 5570/5570 [11:20<00:00,  8.18it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.31it/s]


Train Loss: 0.1924 Train Acc: 92.31%
Val Loss: 0.2152 Val Acc: 91.39%

Epoch 15/50


Training: 100%|██████████| 5570/5570 [11:25<00:00,  8.12it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.35it/s]


Train Loss: 0.1872 Train Acc: 92.62%
Val Loss: 0.2153 Val Acc: 91.30%

Epoch 16/50


Training: 100%|██████████| 5570/5570 [11:18<00:00,  8.21it/s]
Evaluating: 100%|██████████| 619/619 [01:05<00:00,  9.52it/s]


Train Loss: 0.1774 Train Acc: 93.02%
Val Loss: 0.2671 Val Acc: 89.78%

Epoch 17/50


Training: 100%|██████████| 5570/5570 [11:24<00:00,  8.14it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.32it/s]


Train Loss: 0.1703 Train Acc: 93.36%
Val Loss: 0.2386 Val Acc: 90.95%

Epoch 18/50


Training: 100%|██████████| 5570/5570 [11:19<00:00,  8.19it/s]
Evaluating: 100%|██████████| 619/619 [01:05<00:00,  9.44it/s]


Train Loss: 0.1668 Train Acc: 93.46%
Val Loss: 0.2921 Val Acc: 88.97%

Epoch 19/50


Training: 100%|██████████| 5570/5570 [11:26<00:00,  8.12it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.34it/s]


Train Loss: 0.1582 Train Acc: 93.84%
Val Loss: 0.2362 Val Acc: 90.95%

Epoch 20/50


Training: 100%|██████████| 5570/5570 [11:25<00:00,  8.12it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.32it/s]


Train Loss: 0.1527 Train Acc: 94.08%
Val Loss: 0.1998 Val Acc: 92.18%

Epoch 21/50


Training: 100%|██████████| 5570/5570 [11:27<00:00,  8.11it/s]
Evaluating: 100%|██████████| 619/619 [01:05<00:00,  9.43it/s]


Train Loss: 0.1475 Train Acc: 94.31%
Val Loss: 0.2266 Val Acc: 91.05%

Epoch 22/50


Training:  56%|█████▌    | 3132/5570 [06:25<05:01,  8.09it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Training: 100%|██████████| 5570/5570 [11:25<00:00,  8.13it/s]
Evaluating: 100%|██████████| 619/619 [01:03<00:00,  9.71it/s]


Train Loss: 0.0896 Train Acc: 96.62%
Val Loss: 0.2162 Val Acc: 92.47%

Epoch 37/50


Training: 100%|██████████| 5570/5570 [11:31<00:00,  8.05it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.32it/s]


Train Loss: 0.0866 Train Acc: 96.71%
Val Loss: 0.1911 Val Acc: 93.50%

Epoch 38/50


Training: 100%|██████████| 5570/5570 [11:40<00:00,  7.95it/s]
Evaluating: 100%|██████████| 619/619 [01:06<00:00,  9.29it/s]


Train Loss: 0.0841 Train Acc: 96.83%
Val Loss: 0.2265 Val Acc: 93.02%

Epoch 39/50


Training: 100%|██████████| 5570/5570 [13:37<00:00,  6.81it/s]
Evaluating: 100%|██████████| 619/619 [03:08<00:00,  3.28it/s]


Train Loss: 0.0819 Train Acc: 96.93%
Val Loss: 0.1989 Val Acc: 93.39%

Epoch 40/50


Training: 100%|██████████| 5570/5570 [15:50<00:00,  5.86it/s]
Evaluating: 100%|██████████| 619/619 [01:10<00:00,  8.80it/s]


Train Loss: 0.0804 Train Acc: 96.97%
Val Loss: 0.2105 Val Acc: 93.39%

Epoch 41/50


Training: 100%|██████████| 5570/5570 [11:46<00:00,  7.88it/s]
Evaluating: 100%|██████████| 619/619 [01:10<00:00,  8.83it/s]


Train Loss: 0.0782 Train Acc: 97.09%
Val Loss: 0.1952 Val Acc: 93.33%

Epoch 42/50


Training: 100%|██████████| 5570/5570 [11:43<00:00,  7.92it/s]
Evaluating: 100%|██████████| 619/619 [01:08<00:00,  8.98it/s]


Train Loss: 0.0759 Train Acc: 97.14%
Val Loss: 0.2344 Val Acc: 92.48%

Epoch 43/50


Training: 100%|██████████| 5570/5570 [19:45<00:00,  4.70it/s]
Evaluating: 100%|██████████| 619/619 [02:58<00:00,  3.47it/s]


Train Loss: 0.0741 Train Acc: 97.25%
Val Loss: 0.2307 Val Acc: 93.06%

Epoch 44/50


Training: 100%|██████████| 5570/5570 [22:34<00:00,  4.11it/s]
Evaluating: 100%|██████████| 619/619 [01:11<00:00,  8.62it/s]


Train Loss: 0.0717 Train Acc: 97.33%
Val Loss: 0.2031 Val Acc: 93.57%

Epoch 45/50


Training: 100%|██████████| 5570/5570 [11:15<00:00,  8.25it/s]
Evaluating: 100%|██████████| 619/619 [01:05<00:00,  9.44it/s]


Train Loss: 0.0701 Train Acc: 97.39%
Val Loss: 0.2156 Val Acc: 93.26%

Epoch 46/50


Training: 100%|██████████| 5570/5570 [11:06<00:00,  8.36it/s]
Evaluating: 100%|██████████| 619/619 [01:04<00:00,  9.61it/s]


Train Loss: 0.0689 Train Acc: 97.42%
Val Loss: 0.2644 Val Acc: 92.28%

Epoch 47/50


Training: 100%|██████████| 5570/5570 [11:04<00:00,  8.39it/s]
Evaluating: 100%|██████████| 619/619 [01:01<00:00, 10.02it/s]


Train Loss: 0.0659 Train Acc: 97.54%
Val Loss: 0.2228 Val Acc: 93.40%

Epoch 48/50


Training: 100%|██████████| 5570/5570 [11:03<00:00,  8.39it/s]
Evaluating: 100%|██████████| 619/619 [01:03<00:00,  9.77it/s]


Train Loss: 0.0662 Train Acc: 97.50%
Val Loss: 0.2820 Val Acc: 91.48%

Epoch 49/50


Training: 100%|██████████| 5570/5570 [11:05<00:00,  8.37it/s]
Evaluating: 100%|██████████| 619/619 [01:04<00:00,  9.59it/s]


Train Loss: 0.0641 Train Acc: 97.61%
Val Loss: 0.2274 Val Acc: 93.67%

Epoch 50/50


Training:   7%|▋         | 408/5570 [00:49<10:26,  8.24it/s]

In [None]:
# Predict the test set with the saved best checkpoint and write the submission CSV.
predict()