In [1]:
# Autoreload mode 1: only modules registered with %aimport (dataset and
# transforms, in the next cell) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 1
# Render inline figures with a white ('w') facecolor.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold
from torch.utils.data import DataLoader
import timm
from timm import optim, scheduler
import torch
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR
from torchvision import transforms as tfm
from sklearn import metrics as skmet

%aimport dataset
%aimport transforms

VideoData = dataset.VideoData
VideoTransforms = transforms.VideoTransforms

In [3]:
views = ['pdaRelatedView', 'pdaView']  # passed to VideoData as view_filter
modes = ['2d', 'color', 'color_compare']  # passed to VideoData as mode_filter
bs = 16  # batch size for training (the inspection loader below uses its own size)
num_workers = 5  # number of parallel data loading workers (not used by the inspection loader)
res = 224  # pixel size along height and width
device = torch.device('cpu')
num_classes = 3  # NOTE(review): not used in this notebook; presumably model output classes — confirm
seed = 42  # fixes torch (DataLoader shuffling) and numpy RNGs so sampled batches are reproducible

# The DataLoader below uses shuffle=True; without a seed every Restart & Run All
# draws a different batch and the displayed outputs cannot be reproduced.
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
# Build the evaluation ('test') transform pipeline at the configured resolution.
# Bound to `video_transforms` rather than `transforms` so the project module
# registered with `%aimport transforms` is not shadowed in the namespace.
video_transforms = VideoTransforms(res)
tfms = video_transforms.get_transforms('test')

In [5]:
# Every label table comes from the same dated export; change `label_version`
# (or `label_dir`) in one place to load a different snapshot.
label_dir = '../label_data'
label_version = '20220822'

df_frame = pd.read_csv(f'{label_dir}/{label_version}_frame.csv')
df_video = pd.read_csv(f'{label_dir}/{label_version}_video.csv')
df_study = pd.read_csv(f'{label_dir}/{label_version}_study.csv')
df_patient_study = pd.read_csv(f'{label_dir}/{label_version}_patient_study.csv')
df_patient = pd.read_csv(f'{label_dir}/{label_version}_patient.csv')

In [6]:
# Column schema of each label table: patient -> patient_study -> study -> video -> frame.
tuple(table.columns for table in (df_patient, df_patient_study, df_study, df_video, df_frame))

(Index(['patient_id', 'num_studies'], dtype='object'),
 Index(['patient_id', 'patient_type', 'study', 'num_videos'], dtype='object'),
 Index(['patient_type', 'study', 'num_videos', 'num_frames'], dtype='object'),
 Index(['external_id', 'patient_type', 'num_frames', 'view', 'mode',
        'diagnosis', 'study', 'mp4_path'],
       dtype='object'),
 Index(['patient_type', 'external_id', 'png_path'], dtype='object'))

In [7]:
# create datasets
df = df_patient.merge(df_patient_study).merge(df_study, on=['patient_type', 'study']).merge(df_video, on=['patient_type', 'study']).merge(df_frame, on=['patient_type', 'external_id'])
df.head()

Unnamed: 0,patient_id,num_studies,patient_type,study,num_videos_x,num_videos_y,num_frames_x,external_id,num_frames_y,view,mode,diagnosis,mp4_path,png_path
0,01578f3a19fbf0f2,1,pda,study22,82,33,4825,study22_dicom35,77,nonPDAView,2d,,/mnt/data/pda/superior_views/PDA/study22_dicom...,/mnt/data/pda/model_data/20220822/pda_study22_...
1,01578f3a19fbf0f2,1,pda,study22,82,33,4825,study22_dicom35,77,nonPDAView,2d,,/mnt/data/pda/superior_views/PDA/study22_dicom...,/mnt/data/pda/model_data/20220822/pda_study22_...
2,01578f3a19fbf0f2,1,pda,study22,82,33,4825,study22_dicom35,77,nonPDAView,2d,,/mnt/data/pda/superior_views/PDA/study22_dicom...,/mnt/data/pda/model_data/20220822/pda_study22_...
3,01578f3a19fbf0f2,1,pda,study22,82,33,4825,study22_dicom35,77,nonPDAView,2d,,/mnt/data/pda/superior_views/PDA/study22_dicom...,/mnt/data/pda/model_data/20220822/pda_study22_...
4,01578f3a19fbf0f2,1,pda,study22,82,33,4825,study22_dicom35,77,nonPDAView,2d,,/mnt/data/pda/superior_views/PDA/study22_dicom...,/mnt/data/pda/model_data/20220822/pda_study22_...


# Examine the dataset and a sample batch

In [30]:
# Wrap the joined frame table in the project Dataset, filtered by the views and
# imaging modes from the config cell and using the 'test' transforms.
d = VideoData(df, view_filter=views, mode_filter=modes, transforms=tfms)

In [31]:
# Per-video index built by VideoData (rows appear to be one per external_id),
# carrying the integer targets trg_type / trg_view / trg_mode.
d.video_data

Unnamed: 0,study,patient_id,patient_type,external_id,mode,trg_type,trg_view,trg_mode
2225,study22,01578f3a19fbf0f2,pda,study22_dicom58,2d,1,1,0
2296,study22,01578f3a19fbf0f2,pda,study22_dicom59,color_compare,1,1,2
2383,study22,01578f3a19fbf0f2,pda,study22_dicom63,color_compare,1,2,2
2480,study22,01578f3a19fbf0f2,pda,study22_dicom67,color_compare,1,2,2
2745,study22,01578f3a19fbf0f2,pda,study22_dicom69,2d,1,1,0
...,...,...,...,...,...,...,...,...
204109,study36,f8dd48f1f7946612,pda,study36_dicom50,2d,1,2,0
204139,study36,f8dd48f1f7946612,pda,study36_dicom52,color,1,2,1
205192,study36,f8dd48f1f7946612,pda,study36_dicom69,2d,1,1,0
205273,study36,f8dd48f1f7946612,pda,study36_dicom70,color,1,1,1


In [34]:
# Small shuffled loader for eyeballing one batch. num_workers=0 keeps loading in
# the main process (easier to debug); the dataset supplies its own collate.
# NOTE(review): batch_size / num_workers deliberately differ from the `bs` /
# `num_workers` config values — presumably inspection-only; confirm.
dl = DataLoader(
    d,
    batch_size=8,
    shuffle=True,
    collate_fn=d.collate,
    num_workers=0,
)

In [35]:
# Draw a single shuffled batch from the loader.
batch = next(iter(dl))

In [36]:
# Fields produced by d.collate for one batch.
batch.keys()

dict_keys(['video', 'mask', 'trg_type', 'trg_view', 'trg_mode', 'study', 'patient'])

In [37]:
# Collated video tensor. With batch_size=8 the leading dim (910 in the output)
# appears to hold the frames of all videos stacked together — confirm in
# VideoData.collate.
batch['video'].shape

torch.Size([910, 3, 224, 224])

In [40]:
# Frame-to-video mask, shaped (num_frames, batch_size) = (910, 8) in the output.
# The rows shown each have a single True, presumably marking which video the
# frame belongs to — verify in VideoData.collate.
batch['mask'], batch['mask'].shape

(tensor([[ True, False, False,  ..., False, False, False],
         [ True, False, False,  ..., False, False, False],
         [ True, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False,  True],
         [False, False, False,  ..., False, False,  True],
         [False, False, False,  ..., False, False,  True]]),
 torch.Size([910, 8]))

In [42]:
# trg_type target, one entry per video in the batch.
batch['trg_type']

tensor([1, 0, 1, 0, 0, 1, 0, 0])

In [43]:
# trg_view target, one entry per video in the batch.
batch['trg_view']

tensor([1, 1, 2, 2, 2, 1, 2, 2])

In [44]:
# trg_mode target, one entry per video in the batch.
batch['trg_mode']

tensor([0, 0, 1, 0, 1, 1, 2, 2])

In [45]:
# Study identifier of each video in the batch.
batch['study']

['study7',
 'study4',
 'study21',
 'study55',
 'study69',
 'study1',
 'study2',
 'study61']

In [46]:
# Patient identifier of each video in the batch.
batch['patient']

['0a3d6256c8e10c73',
 '627e389054c339f5',
 '17a19a86d4e64b41',
 'efa3614c94506474',
 'f1a239968416da4a',
 '3e70e49acc535a20',
 '07a6b40fa32bef68',
 'f0f5983a3b8324e2']