In [2]:
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
import os
import torch
import torch.nn as nn
from torchvision import transforms, models, datasets
from torch.nn import functional as F
import cv2
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import torch.optim as optim

In [3]:
from collections import OrderedDict
import matplotlib.pyplot as plt
from PIL import Image

In [4]:
import timm

### Data paths

In [5]:
# Root directories can be overridden with environment variables so the notebook
# runs on machines other than the original workstation. The defaults reproduce
# the original hard-coded locations exactly.
LABEL_DIR = os.environ.get(
    "VINBIG_LABEL_DIR",
    "/media/brats/DRIVE1/akansh/Vin-ChestXR-Abnormality-detection/Data/Processed",
)
IMAGE_ROOT = os.environ.get(
    "VINBIG_IMAGE_ROOT",
    "/media/brats/mirlproject2/vinbigdata-chest-xray-abnormalities-detection",
)

# Per-split label CSVs.
labels_csv = {
    'train': os.path.join(LABEL_DIR, "image_labels_train.csv"),
    'test': os.path.join(LABEL_DIR, "image_labels_test.csv"),
}

# Per-split PNG image directories. The empty final join component keeps the
# trailing separator that the original path strings had.
data_dir = {
    'train': os.path.join(IMAGE_ROOT, "vinbig_png", ""),
    'test': os.path.join(IMAGE_ROOT, "test_png", ""),
}

### Dataset class

In [None]:
class Vin_big_dataset(Dataset):
    """Image/label dataset over the VinBigData chest X-ray PNG exports.

    Each item is ``(image, labels)``:
    - ``image`` is the PNG read into a ``uint8`` numpy array and passed through
      ``transforms``.
    - ``labels`` is a 1-D tensor holding that sample's values of the label columns.

    Args:
        image_loc: Directory containing the PNG images.
        label_loc: Path to the label CSV (must contain an ``image_id`` column
            plus the label columns).
        transforms: Callable applied to each image array; ``None`` skips it.
        data_type: ``'train'`` or ``'test'``.
        label_columns: Label column names to read; defaults to ``GLOBAL_LABELS``.

    Splits:
    - ``'train'``: one sample per CSV row, in CSV order. An image may appear
      on several rows (e.g. one per ``rad_id``), and each row yields its own
      sample pointing at ``<image_id>.png``.
    - ``'test'``: one sample per regular, non-hidden file in ``image_loc``,
      sorted by name. Labels are looked up by file stem in the CSV's
      ``image_id`` column.

    Raises:
        ValueError: ``data_type`` is not ``'train'``/``'test'``, or the test CSV
            has duplicate ``image_id`` values.
        KeyError: a label column is missing, or a test image has no CSV row.
    """

    GLOBAL_LABELS = ['Pleural effusion', 'Lung tumor', 'Pneumonia', 'Tuberculosis', 'Other diseases', 'No finding']

    def __init__(self, image_loc, label_loc, transforms, data_type='train', label_columns=None):
        # Fail fast: otherwise full_filenames/labels would never be set and the
        # error would surface later as an AttributeError in __len__.
        if data_type not in ('train', 'test'):
            raise ValueError(f"data_type must be 'train' or 'test', got {data_type!r}")
        columns = list(self.GLOBAL_LABELS if label_columns is None else label_columns)
        label_df = pd.read_csv(label_loc)

        if data_type == 'train':
            # Use image_id directly instead of re-parsing it from an
            # image_id + '_' + rad_id key, which breaks if image_id contains '_'.
            image_ids = label_df['image_id'].astype(str).tolist()
            self.full_filenames = [os.path.join(image_loc, image_id + '.png') for image_id in image_ids]
            # Positional selection keeps CSV row order and stays correct even
            # if an (image_id, rad_id) pair is repeated.
            label_rows = label_df[columns]
        else:
            # os.listdir order is arbitrary; sort for a reproducible sample order.
            # Skip hidden entries and subdirectories (e.g. .ipynb_checkpoints).
            filenames = sorted(
                name for name in os.listdir(image_loc)
                if not name.startswith('.') and os.path.isfile(os.path.join(image_loc, name))
            )
            self.full_filenames = [os.path.join(image_loc, name) for name in filenames]
            label_df = label_df.assign(image_id=label_df['image_id'].astype(str)).set_index('image_id')
            if not label_df.index.is_unique:
                raise ValueError("image_id values in the test label CSV must be unique")
            stems = [os.path.splitext(name)[0] for name in filenames]
            label_rows = label_df.loc[stems, columns]

        # Go through .tolist() so torch infers the dtype from Python scalars,
        # as the original per-row code did (ints -> int64, floats -> float32).
        # Both splits now store a tensor.
        self.labels = torch.tensor(label_rows.to_numpy().tolist())
        self.transforms = transforms
        self.data_type = data_type

    def __len__(self):
        """Number of samples (one per entry in ``full_filenames``)."""
        return len(self.full_filenames)

    def __getitem__(self, idx):
        """Load image ``idx`` as a uint8 array, apply transforms, and pair it with its labels."""
        # The context manager closes the file handle. np.array forces the pixel
        # data to load before the file is closed.
        with Image.open(self.full_filenames[idx]) as img:
            # NOTE(review): a 16-bit PNG would not survive the uint8 cast intact —
            # confirm the exports are 8-bit.
            image = np.array(img, dtype='uint8')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, self.labels[idx]
    
            