In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import random

In [12]:
# Define dataset class
class LensDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ["no", "sphere", "vort"]
        self.image_paths = []
        self.labels = []
        
        # Load file paths and labels
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for file in os.listdir(class_dir):
                if file.endswith('.npy'):
                    self.image_paths.append(os.path.join(class_dir, file))
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        
        image = np.load(image_path)
        
        # Ensure the image has 2D or 3D shape compatible with PIL
        if image.ndim == 2:  # Grayscale image
            image = np.stack([image] * 3, axis=-1)  # Convert to RGB by duplicating channels
        elif image.ndim == 3 and image.shape[0] == 1:  # Single-channel image
            image = np.squeeze(image, axis=0)  # Remove the channel dimension
            image = np.stack([image] * 3, axis=-1)  # Convert to RGB
        
        image = Image.fromarray((image * 255).astype(np.uint8))  # Scale to 0-255 and convert to uint8
        print(image.size)  # Print size instead of shape for PIL images
            
    
        if self.transform:
            image = self.transform(image)
        print(image.shape)
        return image, label

In [13]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for DenseNet/ResNet
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Load train and validation datasets
train_dataset = LensDataset(r"D:\my_study\GSOC\dataset\dataset\train", transform=transform)
val_dataset = LensDataset(r"D:\my_study\GSOC\dataset\dataset\val", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [14]:
for i in iter(train_loader):
    print(i[0].shape)
    break

(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(150, 150)
torch.Size([3, 224, 224])
(

In [16]:
for root, dirs, files in os.walk('D:\my_study\GSOC\dataset'):
    level = root.replace('dataset', '').count(os.sep)
    indent = ' ' * 4 * level
    print(f"{indent}{os.path.basename(root)}/")
    sub_indent = ' ' * 4 * (level + 1)
    # for f in files:
    #     print(f"{sub_indent}{f}")

            dataset/
                dataset/
                    train/
                        no/
                        sphere/
                        vort/
                    val/
                        no/
                        sphere/
                        vort/
                __MACOSX/
                    dataset/
                        train/
                        val/
