In [1]:
import os
from collections import Counter

import evaluate
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from transformers import AutoImageProcessor, ResNetForImageClassification
import tkinter as tk
from tkinter import filedialog, Label, Button

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

In [2]:
# Define a custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, label, transform=None):
        self.image_dir = image_dir
        self.label = label
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.label

In [3]:
# Define transformation with data augmentation
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [4]:
root = 'C:/Users/nithi/Desktop/FAU/Semester-4/Master_Thesis_Federated_Learning/Dataset/sample/images'

# Create datasets
train_atelectasis = CustomImageDataset(root + '/Atelectasis/train', 0, train_transform)
train_cardiomegaly = CustomImageDataset(root + '/Cardiomegaly/train', 1, train_transform)
val_atelectasis = CustomImageDataset(root + '/Atelectasis/val', 0, val_test_transform)
val_cardiomegaly = CustomImageDataset(root + '/Cardiomegaly/val', 1, val_test_transform)
test_atelectasis = CustomImageDataset(root + '/Atelectasis/test', 0, val_test_transform)
test_cardiomegaly = CustomImageDataset(root + '/Cardiomegaly/test', 1, val_test_transform)

In [5]:
# Combine train datasets and calculate class weights
train_dataset = train_atelectasis + train_cardiomegaly
class_counts = Counter([label for _, label in train_dataset])
class_weights = {0: 1.0 / class_counts[0], 1: 1.0 / class_counts[1]}
sample_weights = [class_weights[label] for _, label in train_dataset]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

val_dataset = val_atelectasis + val_cardiomegaly
test_dataset = test_atelectasis + test_cardiomegaly

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [6]:
# Load pre-trained model and processor
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50", num_labels=2, ignore_mismatched_sizes=True)
model.config.id2label = {0: 'Atelectasis', 1: 'Cardiomegaly'}
model.config.label2id = {'Atelectasis': 0, 'Cardiomegaly': 1}

# Add dropout for regularization
model.classifier.dropout = torch.nn.Dropout(p=0.5)

# Define training parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([class_weights[0], class_weights[1]]).to(device))

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([2, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Training loop with early stopping
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10, patience=3):
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(pixel_values=images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(pixel_values=images).logits
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                break

In [8]:
# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10)

Epoch 1/10, Train Loss: 0.6199410023360417, Val Loss: 0.9108462631702423
Epoch 2/10, Train Loss: 0.5802430644117552, Val Loss: 1.0615930344377245
Epoch 3/10, Train Loss: 0.5157384379156704, Val Loss: 1.2677611815077918
Epoch 4/10, Train Loss: 0.4812785782690706, Val Loss: 0.6747127260480609
Epoch 5/10, Train Loss: 0.4943594254296401, Val Loss: 1.1991265820605415
Epoch 6/10, Train Loss: 0.4770188023304117, Val Loss: 0.8957285971513816
Epoch 7/10, Train Loss: 0.4847736492239196, Val Loss: 0.7801030831677573
Early stopping


In [9]:
# Load the best model
model.load_state_dict(torch.load('best_model.pt'))

<All keys matched successfully>

In [10]:
# Evaluate the model
metric = evaluate.load("accuracy", trust_remote_code=True)
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(pixel_values=images).logits
        predictions = torch.argmax(outputs, dim=-1)
        metric.add_batch(predictions=predictions, references=labels)

accuracy = metric.compute()['accuracy']
print(f"Test Accuracy: {accuracy}")

Test Accuracy: 0.5858585858585859


In [11]:
# Predict using the trained model
def predict_image(image_path, model, processor):
    if not os.path.exists(image_path):
        print(f"Image path {image_path} does not exist.")
        return None
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]


In [13]:
# GUI to select an image and predict
def open_file():
    file_path = filedialog.askopenfilename()
    if file_path:
        label.config(text=f"Selected image: {file_path}")
        predicted_label = predict_image(file_path, model, processor)
        result_label.config(text=f"Predicted label: {predicted_label}")
        
# Create the GUI application
root = tk.Tk()
root.title("Image Classification")

label = Label(root, text="Select an image to classify")
label.pack(pady=10)

button = Button(root, text="Select Image", command=open_file)
button.pack(pady=10)

result_label = Label(root, text="")
result_label.pack(pady=10)

root.mainloop()