In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import time

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import OneCycleLR

# dataset
class BreakHisDataset(Dataset):
  """Image dataset backed by a CSV listing of files and their split.

  The CSV must provide a ``grp`` column (matched case-insensitively
  against ``"train"`` / ``"test"``) and a ``filename`` column holding
  paths relative to ``root_dir``. The class label is not read from the
  CSV; it is derived from a keyword found in the filename.

  Args:
    csv_file: Path (or file-like object) accepted by ``pandas.read_csv``.
    root_dir: Directory that the ``filename`` entries are joined onto.
    train: Keep the ``"train"`` rows when true, the ``"test"`` rows otherwise.
    transform: Optional callable applied to the RGB PIL image.
  """

  # Ordered (keyword, label) pairs. The first keyword contained in the
  # lower-cased filename wins, so the order is part of the contract.
  _LABEL_KEYWORDS = (
      ("adenosis", 0),
      ("fibroadenoma", 1),
      ("phyllodes_tumor", 2),
      ("tubular_adenoma", 3),
      ("ductal_carcinoma", 4),
      ("lobular_carcinoma", 5),
      ("mucinous_carcinoma", 6),
      ("papillary_carcinoma", 7),
  )

  def __init__(self, csv_file, root_dir, train: bool = True, transform=None):
    frame = pd.read_csv(csv_file)
    self.root_dir = root_dir
    self.transform = transform

    wanted_group = "train" if train else "test"
    in_group = frame['grp'].str.lower() == wanted_group
    # Build a fresh, re-indexed frame instead of mutating the filtered
    # slice in place; this keeps pandas from warning about chained
    # assignment and makes positional and label indices agree.
    self.data_frame = frame[in_group].reset_index(drop=True)

  def __len__(self) -> int:
    """Return the number of rows kept for the selected split."""
    return len(self.data_frame)

  @classmethod
  def _label_from_filename(cls, filename: str) -> int:
    """Map a filename to its integer class label.

    Raises:
      ValueError: If none of the known class keywords occurs in the name.
    """
    lowered = filename.lower()
    for keyword, label in cls._LABEL_KEYWORDS:
      if keyword in lowered:
        return label
    raise ValueError(f"Cannot determine label from filename: {filename}")

  def __getitem__(self, idx):
    """Return ``(image, label)`` for the row at position ``idx``.

    Raises:
      ValueError: If the label cannot be derived from the filename.
    """
    filename = self.data_frame.iloc[idx]['filename']
    img_path = os.path.join(self.root_dir, filename)

    # convert() returns a new, fully loaded image, so the underlying file
    # can be closed deterministically as soon as the conversion is done.
    with Image.open(img_path) as raw_image:
      image = raw_image.convert('RGB')

    label = self._label_from_filename(filename)

    # Explicit None check: a transform object is a callable, and its
    # truthiness is not a reliable "was it provided" signal.
    if self.transform is not None:
      image = self.transform(image)
    return image, label

def calculate_mean_std(dataset, batch_size=64, num_workers=4):
  """Compute per-channel mean and standard deviation over a dataset.

  Every pixel of every image contributes equally: per-channel sums and
  sums of squares are accumulated over the whole dataset and the
  population standard deviation is derived from them. This is the
  dataset-wide std; it is not the average of per-image stds, which is
  smaller whenever images differ from each other in overall intensity.

  Args:
    dataset: Dataset yielding ``(image, label)`` pairs whose images
      collate into a batch with the channel axis at dimension 1.
    batch_size: Number of images per batch.
    num_workers: Worker processes used by the ``DataLoader``.

  Returns:
    A ``(mean, std)`` tuple of float32 tensors, one value per channel.

  Raises:
    ValueError: If the dataset yields no samples.
  """
  loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  channel_sum = 0.0
  channel_sq_sum = 0.0
  pixel_count = 0

  for data, _ in loader:
    # Flatten the spatial dimensions; accumulate in float64 because the
    # E[x^2] - E[x]^2 formula is sensitive to rounding in float32.
    flat = data.reshape(data.size(0), data.size(1), -1).double()
    channel_sum = channel_sum + flat.sum(dim=(0, 2))
    channel_sq_sum = channel_sq_sum + (flat * flat).sum(dim=(0, 2))
    pixel_count += flat.size(0) * flat.size(2)

  if pixel_count == 0:
    raise ValueError("Cannot compute mean/std of an empty dataset")

  mean = channel_sum / pixel_count
  # Clamp guards against tiny negative variances caused by rounding.
  variance = (channel_sq_sum / pixel_count - mean * mean).clamp(min=0.0)
  std = variance.sqrt()
  return mean.float(), std.float()

# Inputs for the statistics pass: target image size, the CSV that lists
# the files with their split, and the directory the filenames are
# relative to.
img_size = (256, 256)
csv_file = "Folds.csv"
root_dir = "./../BreaKHis_v1/"

# Only resize and convert to a tensor here; no normalisation is applied
# because this pass is what measures the per-channel statistics.
train_transform_for_stats = T.Compose([T.Resize(img_size), T.ToTensor()])

# Statistics are measured on the training split only.
stats_dataset = BreakHisDataset(
    csv_file=csv_file,
    root_dir=root_dir,
    train=True,
    transform=train_transform_for_stats,
)

mean, std = calculate_mean_std(stats_dataset)
print("Mean:", mean)
print("Std:", std)




Mean: tensor([0.7872, 0.6222, 0.7640])
Std: tensor([0.1005, 0.1330, 0.0837])
