In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import cv2
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from warnings import filterwarnings
filterwarnings("ignore")

In [2]:
try:
    # if use GPU, use it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {torch.cuda.get_device_name()} for training." if torch.cuda.is_available() else "Using CPU for training")
except:
    print(f"No GPU found. Using CPU for training.")
    device = torch.device("cpu")

Using NVIDIA GeForce RTX 3060 Ti for training.


In [25]:
dataset_path = 'Dataset/'
class_names = sorted(os.listdir(dataset_path))
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
batch_size = 32
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [26]:
class_names

['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']

In [27]:
# Processing images and labels in each folder
for batch in data_loader:
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)

In [28]:
total_images, total_labels = 0, 0
for batch in data_loader:
    images, labels = batch
    total_images += images.size(0)
    total_labels += len(labels)

In [29]:
print(f"Total images: {total_images}", f"\nTotal labels: {total_labels}")    

Total images: 10287 
Total labels: 10287


In [30]:
train_images, test_images, train_labels, test_labels = train_test_split(images, labels, test_size=0.15, random_state=42)
train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_labels, test_size=0.15, random_state=42)

dataset_list = [train_images, val_images, test_images, train_labels, val_labels, test_labels]

for i in dataset_list:
    print(i.shape)

torch.Size([10, 3, 224, 224])
torch.Size([2, 3, 224, 224])
torch.Size([3, 3, 224, 224])
torch.Size([10])
torch.Size([2])
torch.Size([3])
