In [3]:
# Import stuff
from torchvision import transforms, datasets
import os
from PIL import Image
import numpy as np
import platform
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torch.utils.data import DataLoader, random_split
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from ultralytics import YOLO
import shutil
import cv2

In [4]:
# Choose the compute device: CUDA if present, otherwise Apple's MPS backend
# (looked up with getattr so PyTorch builds without torch.backends.mps don't
# raise AttributeError), otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device.type}\n")

Using device: mps



# Data Preparation

In [31]:
# Survey the raw dataset: count images whose size differs from `img_size`,
# how many of those are non-square, and how many are not in RGB mode.
# Defines `root`, `splits`, `categories` and `img_size`, which later cells reuse.
root = 'brain-mri'

# Treat only real subdirectories as splits/categories. This skips .DS_Store and
# any other stray file os.listdir returns (a plain file here would otherwise be
# joined into a path and crash the listdir below). Sorting makes the order --
# and therefore which split `splits[0]` is -- deterministic.
splits = sorted(
    entry for entry in os.listdir(root)
    if os.path.isdir(os.path.join(root, entry))
)
categories = sorted(
    entry for entry in os.listdir(os.path.join(root, splits[0]))
    if os.path.isdir(os.path.join(root, splits[0], entry))
)
img_size = (512, 512)

count = 0          # images whose size != img_size
non_square = 0     # subset of `count` whose width != height
non_RGB_count = 0  # images whose PIL mode is not 'RGB'
for split in splits:
    for category in categories:
        path = os.path.join(root, split, category, 'images')
        for img_name in os.listdir(path):
            if not img_name.endswith('.jpg'):
                continue
            img_path = os.path.join(path, img_name)
            # Image.open is lazy and keeps the file open until the object is
            # collected; the context manager releases the handle immediately.
            with Image.open(img_path) as img:
                mode, size = img.mode, img.size
            if mode != 'RGB':
                non_RGB_count += 1
            if size != img_size:
                count += 1
                if size[0] != size[1]:
                    non_square += 1
print(f"Total images with incorrect size: {count}")
print(f"Total non-square images: {non_square}")
print(f"Total non-RGB images: {non_RGB_count}")

Total images with incorrect size: 1043
Total non-square images: 697
Total non-RGB images: 1002


In [None]:
# Build `resized-data/`: a copy of every labelled image, converted to RGB and
# resized to exactly `img_size`, plus its label file. Relies on `root`, `splits`,
# `categories` and `img_size` from the survey cell.
#
# Labels are copied unchanged for every image. They are YOLO-style
# "class x_center y_center width height" lines, assumed normalized to [0, 1]
# (the YOLO convention, and what the square-image path already relied on --
# TODO confirm if any labels are in pixels). img.resize() stretches each axis
# independently, so normalized coordinates do not change. An earlier version
# multiplied them by original_size / 512 for non-square images, which misplaced
# boxes and pushed values past 1.0 for images larger than 512 px.
resized_dir = 'resized-data'
shutil.rmtree(resized_dir, ignore_errors=True)
os.makedirs(resized_dir, exist_ok=True)

# Create the output directory tree up front.
for split in splits:
    for category in categories:
        for sub in ('images', 'labels'):
            os.makedirs(os.path.join(resized_dir, split, category, sub), exist_ok=True)

# Resize images and copy labels
for split in splits:
    for category in categories:
        img_path = os.path.join(root, split, category, 'images')
        label_path = os.path.join(root, split, category, 'labels')
        out_img_dir = os.path.join(resized_dir, split, category, 'images')
        out_label_dir = os.path.join(resized_dir, split, category, 'labels')
        for img_name in os.listdir(img_path):
            # Skip .DS_Store and whatever else is not a .jpg image
            if not img_name.endswith('.jpg'):
                continue

            # Skip images with no labels. Only the trailing suffix is swapped
            # (str.replace would also rewrite any '.jpg' inside the stem).
            label_file = img_name[:-len('.jpg')] + '.txt'
            label_full_path = os.path.join(label_path, label_file)
            if not os.path.exists(label_full_path):
                continue

            img_full_path = os.path.join(img_path, img_name)
            out_img_path = os.path.join(out_img_dir, img_name)
            with Image.open(img_full_path) as img:
                if img.mode == 'RGB' and img.size == img_size:
                    # Already in the target format: copy the bytes verbatim
                    # (avoids a lossy JPEG re-encode).
                    shutil.copy(img_full_path, out_img_path)
                else:
                    # Convert first so non-RGB sources (including ones that are
                    # already 512x512) are actually written out as RGB, then
                    # stretch to img_size (aspect ratio is not preserved).
                    img.convert('RGB').resize(img_size).save(out_img_path)

            # A stretching resize leaves normalized YOLO labels unchanged.
            shutil.copy(label_full_path, os.path.join(out_label_dir, label_file))

# 11 seconds on laptop