In [92]:
import os
import torchvision.models as models 
import torch.nn as nn
import torch
import pandas as pd
from torchvision.io import read_image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torchvision
from glob import glob
from torchinfo import summary
import numpy as np
import torch.functional as F
import torchvision.transforms as T
from tqdm import tqdm_notebook
device0 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_load

In [None]:
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_glob, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_list = glob(img_glob)
        self.transform = transform
        self.target_transform = target_transform
        self.file_name_list=[]
        self.file_remove_list=[]
        self.label_list=[]
        
        for i in range(len(self.img_list)):
            self.file_name_list.append(os.path.splitext(os.path.basename(self.img_list[i]))[0][:9])

        for i in range(len(self.file_name_list)):
            label_index=self.img_labels.loc[self.img_labels['PathologyNumber'] == self.file_name_list[i]]
            if len(label_index)==0:
                self.file_remove_list.append(self.img_list[i])
            else:
                if label_index['MMR status'].to_list()[0]=='Normal':
                    self.label_list.append(0)
                else:
                    self.label_list.append(1)
        for i in range(len(self.file_remove_list)):
            self.img_list.remove(self.file_remove_list[i])
        self.image_x5=[f.replace('/2.5x_standard', '/5x_standard') for f in self.img_list]
        self.image_x10=[f.replace('/2.5x_standard', '/10x_standard') for f in self.img_list]
        img_5x_temp=[]
        img_10x_temp=[]
        for i in range(len(self.img_list)):
            for j in range(4):
                img_5x_temp.append(self.image_x5[i][:-4]+'_'+str(j)+'.jpg')
            for j in range(16):
                img_10x_temp.append(self.image_x10[i][:-4]+'_'+str(j)+'.jpg')
            self.image_x5[i]=img_5x_temp
            self.image_x10[i]=img_10x_temp
        self.transform = T.Resize(224)
    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        image = (1-self.transform(read_image(self.img_list[idx])))/255
        label = self.label_list[idx]
        image_5x=torch.zeros(4,3,224,224)
        image_10x=torch.zeros(16,3,224,224)
        for i in range(4):
            image_5x[i]=self.transform(read_image(self.image_x5[idx][i]))
        for i in range(16):
            image_10x[i]=self.transform(read_image(self.image_x10[idx][i]))
        image_5x=np.transpose(1-image_5x, axes=(1, 0, 2,3))/255
        image_10x=np.transpose(1-image_10x, axes=(1, 0, 2,3))/255
        image_5x=torch.reshape(1-image_5x,(3*4,224,224))
        image_10x=torch.reshape(1-image_10x,(16*3,224,224))
        
        return image,image_5x,image_10x,label

In [None]:
img_glob='../../data/CycleGANData/3divisionTile/2.5x_standard/*.jpg'
annotations_file='../../data/OriginalData/MMR.csv'
dataset=CustomImageDataset(annotations_file,img_glob)
dataset_size = len(dataset)
train_size = int(dataset_size * 0.8)
validation_size = int(dataset_size * 0.1)
test_size = dataset_size - train_size - validation_size
train_dataset, validation_dataset, test_dataset = random_split(dataset, [train_size, validation_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=16, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, drop_last=True)

model

In [236]:
class CNN1D_1(nn.Module):
    def __init__(self):
        super(CNN1D_1,self).__init__()
        self.layer1 = torch.nn.Sequential(
        torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3),
        torch.nn.MaxPool1d(kernel_size=2,stride=2),
        torch.nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3),
        torch.nn.MaxPool1d(kernel_size=2,stride=2) 
        ,
        torch.nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3),
        torch.nn.MaxPool1d(kernel_size=2,stride=2) 
        ,nn.LSTM(input_size=123, hidden_size=64,bidirectional=True)
        )
    def forward(self,x):
        x=self.layer1(x)
        return x

class CNN1D_2(nn.Module):
    def __init__(self):
        super(CNN1D_2,self).__init__()
        self.conv1d=torch.nn.Sequential(torch.nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3),
                                        torch.nn.MaxPool1d(kernel_size=2,stride=2),
                                        nn.Flatten(start_dim=1,end_dim=-1),
                                        nn.LSTM(input_size=4032, hidden_size=64),
                                        nn.Linear(in_features = 64, out_features = 256),
                                        nn.LeakyReLU(),
                                        nn.Dropout1d(0.5),
                                        nn.Linear(in_features = 256, out_features = 128),
                                        nn.LeakyReLU(),
                                        nn.Dropout1d(0.5),
                                        nn.Linear(in_features = 256, out_features = 1))
    def forward(self,x):
        x=self.conv1d(x)
        return x
    
class CNN1D(nn.Module):
    def __init__(self):
        super(CNN1D,self).__init__()
        self.cnn1d_1=CNN1D_1().to(device0)
        self.cnn1d_2=CNN1D_2().to(device0)
    def forward(self,x):
        x=self.cnn1d_1(x)
        x=self.cnn1d_2(x)
        return x
    
model = CNN1D().to(device0)
criterion = nn.BCEWithLogitsLoss().to(device0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [CNN1D_1: 1, Sequential: 2, Conv1d: 3, MaxPool1d: 3, Conv1d: 3, MaxPool1d: 3, Conv1d: 3, MaxPool1d: 3, LSTM: 3]