In [1]:
from collections import OrderedDict

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from torchvision.datasets import CIFAR10
import flwr as fl
import importlib
import os

from src.models import nets
from src.data_loader import ALLDataset
from tqdm import tqdm
import numpy as np
from pathlib import Path
import pickle

Initiating SANITY CHECK.


In [11]:
CSV_PATH = os.environ['csv_path']
DATASET_PATH = os.environ['dataset_path']
DATA_LOADER_TYPE= os.getenv('data_loader_type',"optimam")
config_file = 'config.yaml'
with open(config_file) as file:
  CONFIG = yaml.safe_load(file)

In [12]:
def import_class(name):
    module_name, class_name = name.rsplit('.', 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
print(f'Here dataset path {DATASET_PATH}')
print(f'Here csv path {CSV_PATH}')

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CRITERION = import_class(CONFIG['hyperparameters']['criterion'])

# training_loader = DataLoader(ALLDataset(DATASET_PATH, CSV_PATH, 'train'), batch_size=CONFIG['hyperparameters']['batch_size'])
training_loader = DataLoader(ALLDataset(DATASET_PATH, CSV_PATH, mode='train', data_loader_type=DATA_LOADER_TYPE, load_max=CONFIG['data']['load_max']), batch_size=CONFIG['hyperparameters']['batch_size'])


Here dataset path /home/lidia-garrucho/datasets/INBREAST/AllPNG_cropped
Here csv path /home/lidia-garrucho/datasets/INBREAST/INbreast_updated_cropped_breast.csv


100%|██████████| 36/36 [00:00<00:00, 123563.78it/s]


In [20]:
net = nets.SqueezeNetClassifier(in_ch=3, out_ch=1, linear_ch=512, pretrained=False)
net.to(DEVICE)
criterion=CRITERION()
epochs=3
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
losses = []
cumulative_loss = 0.0
predictions = []

Using cache found in /home/akis-linardos/.cache/torch/hub/pytorch_vision_v0.6.0


In [21]:
print('Training...')
for _ in range(epochs):
    for i, batch in enumerate(tqdm(training_loader)):
        images, labels = batch[0].to(DEVICE), batch[1].to(DEVICE).unsqueeze(1)
        optimizer.zero_grad()
        preds = net(images)
        loss = criterion(net(images), labels)
        cumulative_loss += loss.item()
        loss.backward()
        optimizer.step()
        losses.append(loss)

train_results = cumulative_loss #(losses, predictions)

Training...


100%|██████████| 33/33 [00:08<00:00,  3.76it/s]
100%|██████████| 33/33 [00:08<00:00,  3.99it/s]
100%|██████████| 33/33 [00:08<00:00,  4.05it/s]


In [25]:
net = nets.ResNet101Classifier(in_ch=3, out_ch=1, linear_ch=2048, pretrained=False)
net.to(DEVICE)
criterion=CRITERION()
epochs=3
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
losses = []
cumulative_loss = 0.0
predictions = []
print('Training...')
for _ in range(epochs):
    for i, batch in enumerate(tqdm(training_loader)):
        images, labels = batch[0].to(DEVICE), batch[1].to(DEVICE).unsqueeze(1)
        optimizer.zero_grad()
        loss = criterion(net(images), labels)
        cumulative_loss += loss.item()
        loss.backward()
        optimizer.step()

        losses.append(loss)

train_results = cumulative_loss #(losses, predictions)

Using cache found in /home/akis-linardos/.cache/torch/hub/pytorch_vision_v0.10.0


Training...


100%|██████████| 33/33 [00:09<00:00,  3.39it/s]
100%|██████████| 33/33 [00:09<00:00,  3.40it/s]
100%|██████████| 33/33 [00:09<00:00,  3.41it/s]


In [26]:
# CLient does this:
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)

NameError: name 'parameters' is not defined

In [23]:
%debug

> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torch/nn/functional.py[0m(1848)[0;36mlinear[0;34m()[0m
[0;32m   1846 [0;31m    [0;32mif[0m [0mhas_torch_function_variadic[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1847 [0;31m        [0;32mreturn[0m [0mhandle_torch_function[0m[0;34m([0m[0mlinear[0m[0;34m,[0m [0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m)[0m[0;34m,[0m [0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m=[0m[0mbias[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1848 [0;31m    [0;32mreturn[0m [0mtorch[0m[0;34m.[0m[0m_C[0m[0;34m.[0m[0m_nn[0m[0;34m.[0m[0mlinear[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1849 [0;31m[0;34m[0m[0m
[0m[0;32m   1850 [0;31m[0;34m

ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torch/nn/modules/linear.py[0m(103)[0;36mforward[0;34m()[0m
[0;32m    101 [0;31m[0;34m[0m[0m
[0m[0;32m    102 [0;31m    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput[0m[0;34m:[0m [0mTensor[0m[0;34m)[0m [0;34m->[0m [0mTensor[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 103 [0;31m        [0;32mreturn[0m [0mF[0m[0;34m.[0m[0mlinear[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mweight[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mbias[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    104 [0;31m[0;34m[0m[0m
[0m[0;32m    105 [0;31m    [0;32mdef[0m [0mextra_repr[0m[0;34m([0m[0mself[0m[0;34m)[0m [0;34m->[0m [0mstr[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torch/nn/modules/module.py[0m(1102)[0;36m_call_impl[0;34m()[0m
[0;32m   1100 [0;31m        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
[0m[0;32m   1101 [0;31m                or _global_forward_hooks or _global_forward_pre_hooks):
[0m[0;32m-> 1102 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1103 [0;31m        [0;31m# Do not call functions when jit is used[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1104 [0;31m        [0mfull_backward_hooks[0m[0;34m,[0m [0mnon_full_backward_hooks[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torchvision/models/resnet.py[0m(244)[0;36m_forward_impl[0;34m()[0m
[0;32m    242 [0;31m        [0mx[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mavgpool[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    243 [0;31m        [0mx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mflatten[0m[0;34m([0m[0mx[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 244 [0;31m        [0mx[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfc[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    245 [0;31m[0;34m[0m[0m
[0m[0;32m    246 [0;31m        [0;32mreturn[0m [0mx[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torchvision/models/resnet.py[0m(249)[0;36mforward[0;34m()[0m
[0;32m    247 [0;31m[0;34m[0m[0m
[0m[0;32m    248 [0;31m    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mx[0m[0;34m:[0m [0mTensor[0m[0;34m)[0m [0;34m->[0m [0mTensor[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 249 [0;31m        [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_forward_impl[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    250 [0;31m[0;34m[0m[0m
[0m[0;32m    251 [0;31m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torch/nn/modules/module.py[0m(1102)[0;36m_call_impl[0;34m()[0m
[0;32m   1100 [0;31m        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
[0m[0;32m   1101 [0;31m                or _global_forward_hooks or _global_forward_pre_hooks):
[0m[0;32m-> 1102 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1103 [0;31m        [0;31m# Do not call functions when jit is used[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1104 [0;31m        [0mfull_backward_hooks[0m[0;34m,[0m [0mnon_full_backward_hooks[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torch/nn/modules/container.py[0m(141)[0;36mforward[0;34m()[0m
[0;32m    139 [0;31m    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m        [0;32mfor[0m [0mmodule[0m [0;32min[0m [0mself[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 141 [0;31m            [0minput[0m [0;34m=[0m [0mmodule[0m[0;34m([0m[0minput[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    142 [0;31m        [0;32mreturn[0m [0minput[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    143 [0;31m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/akis-linardos/.local/lib/python3.6/site-packages/torch/nn/modules/module.py[0m(1102)[0;36m_call_impl[0;34m()[0m
[0;32m   1100 [0;31m        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
[0m[0;32m   1101 [0;31m                or _global_forward_hooks or _global_forward_pre_hooks):
[0m[0;32m-> 1102 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1103 [0;31m        [0;31m# Do not call functions when jit is used[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1104 [0;31m        [0mfull_backward_hooks[0m[0;34m,[0m [0mnon_full_backward_hooks[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m<ipython-input-22-5ad14b1fdef0>[0m(14)[0;36m<module>[0;34m()[0m
[0;32m     12 [0;31m        [0mimages[0m[0;34m,[0m [0mlabels[0m [0;34m=[0m [0mbatch[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m.[0m[0mto[0m[0;34m([0m[0mDEVICE[0m[0;34m)[0m[0;34m,[0m [0mbatch[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m.[0m[0mto[0m[0;34m([0m[0mDEVICE[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m        [0moptimizer[0m[0;34m.[0m[0mzero_grad[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 14 [0;31m        [0mloss[0m [0;34m=[0m [0mcriterion[0m[0;34m([0m[0mnet[0m[0;34m([0m[0mimages[0m[0;34m)[0m[0;34m,[0m [0mlabels[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m        [0mcumulative_loss[0m [0;34m+=[0m [0mloss[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [

ipdb>  net(images)


*** RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x2048 and 512x1)


ipdb>  q


In [None]:
train_results

In [None]:
seed = 42  # for reproducibility

# Imports
import os
import yaml
import enum
import copy
import random
import tempfile
import warnings
import multiprocessing
import pickle
import numpy as np
from collections import OrderedDict
import sys
sys.path.append('/BFP')

from math import floor, ceil
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
torch.manual_seed(seed)
from PIL import Image
from src.data_augmentation.breast_density.data.resize_image import *
from src.preprocessing.histogram_standardization import apply_hist_stand_landmarks
from src.data_handling.mmg_detection_datasets import *

from torch.utils.data import BatchSampler, RandomSampler 

    
pathologies = ['mass'] #['mass', 'calcifications', 'suspicious_calcifications', 'architectural_distortion'] # None to select all
# Resize images keeping aspect ratio
rescale_height = 224
rescale_width = 224

image_ctr = 0

def preprocess_one_image_OPTIMAM(image):
    label = np.single(0) if image.status=='Benign' else np.single(1)
    # status = image.status # ['Benign', 'Malignant', 'Interval Cancer', 'Normal']
    manufacturer = image.manufacturer # ['HOLOGIC, Inc.', 'Philips Digital Mammography Sweden AB', 'GE MEDICAL SYSTEMS', 'Philips Medical Systems', 'SIEMENS']
    # view = image.view # MLO_VIEW = ['MLO','LMLO','RMLO', 'LMO', 'ML'] CC_VIEW = ['CC','LCC','RCC', 'XCCL', 'XCCM']
    # laterality = image.laterality # L R

    img_pil = Image.open(image.path).convert('RGB')
    img_np = np.array(img_pil)
    scale_size = (rescale_height, rescale_width)
    img_np = np.uint8(img_np) if img_np.dtype != np.uint8 else img_np.copy()
    rescaled_img, scale_factor = imrescale(img_np, scale_size, return_scale=True, backend='pillow')
    image = torch.from_numpy(rescaled_img).permute(2,0,1)
    
    # Histogram Matching 
    landmarks_values = torch.load(HOME_PATH / CONFIG['paths']['landmarks'])
    apply_hist_stand_landmarks(image, landmarks_values)

    paddedimg = torch.zeros(3,224,224)
    c,h,w = image.shape
    paddedimg[:,-h:,-w:]=image
    return paddedimg, label



In [None]:
mode='train'
load_max=1000
center=None 
subjects = OPTIMAMDataset(CSV_PATH, DATASET_PATH, detection=False, load_max=-1, 
                    cropped_to_breast=True) # we should be able to load any dataset with this

subjects_selected = {}
if center!=None:
    total_subjects = subjects.get_images_by_site(center)
else:
    # General case
    subjects_selected['benign'] = subjects.get_clients_by_status('Benign')[:load_max] #Note that clients means subjects here.
    subjects_selected['malignant'] = subjects.get_clients_by_status('Malignant')[:load_max]
    subjects_selected['normal'] = subjects.get_clients_by_status('Normal')[:load_max]
    total_subjects = subjects_selected['benign'] + subjects_selected['malignant'] + subjects_selected['normal']
random.shuffle(total_subjects) 
# Data Split
training_subjects = total_subjects[:int(len(total_subjects)*0.8)]
validation_subjects = total_subjects[int(len(total_subjects)*0.8):int(len(total_subjects)*0.9)]
test_subjects = total_subjects[int(len(total_subjects)*0.9):]

def extract_images(subjects):
    images=[]
    for subject in tqdm(subjects):
        for study in subject:
            for image in study:
                images.append(image)
    return images

# if mode == 'train':
#     self.images = extract_images(training_subjects)
# elif mode == 'val':
#     self.images = extract_images(validation_subjects) 
# elif mode == 'test':
#     self.images = extract_images(test_subjects)