<a href="https://colab.research.google.com/github/RaviSriTejaKuriseti/YogaPoseDetection/blob/main/Load.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2
import numpy as np
import pandas as pd
from skimage import io,transform
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn import preprocessing
from tqdm import tqdm
import torch
from torch.autograd import Variable
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout
from torch.optim import Adam, SGD
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.io import read_image
import glob
import os
from IPython.display import Image
import sys

class ImageDataset(Dataset):
    """Map-style dataset of images listed in a header-less CSV file.

    Column 0 of the CSV holds the image path. When ``train`` is True,
    column 1 holds the raw class label, which is converted to an integer
    class index with the ``le`` label encoder.
    """

    def __init__(self, data_csv, train=True, to_encode=True, img_transform=None,
                 label_encoder=None):
        """
        Dataset init function

        INPUT:
        data_csv: Path to a header-less csv file; column 0 is the image path and,
            for labelled data, column 1 is the label.
        train:
            True: the csv file has labels (Train data and Public Test Data).
            False: the csv file has only image paths; every label is returned as -1.
        to_encode:
            True: fit the label encoder on the labels found in this csv file.
            False: use the encoder as-is; it must already be fitted (pass it via
                label_encoder or assign it to the `le` attribute) before items are read.
        img_transform: Callable applied to each loaded image, or None to return
            the image exactly as loaded.
        label_encoder: Optional encoder object exposing fit/transform to use instead
            of a fresh LabelEncoder (e.g. the one fitted on the training set).
        """
        self.data_csv = data_csv
        self.img_transform = img_transform
        self.is_train = train
        self.le = preprocessing.LabelEncoder() if label_encoder is None else label_encoder
        self.fit_to_encode = to_encode
        self._load_index()

    def _load_index(self):
        """Read the csv once and cache the image paths (and raw labels, if any)."""
        table = pd.read_csv(self.data_csv, header=None)
        self.images = table.iloc[:, 0].tolist()
        if self.is_train:
            # The label column is only touched for labelled data, so a
            # path-only csv works when train=False.
            self.raw_labels = table.iloc[:, 1].tolist()
            if self.fit_to_encode:
                self.le.fit(self.raw_labels)
        else:
            self.raw_labels = None

    def __len__(self):
        """Returns total number of samples in the dataset"""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Loads image of the given index and performs preprocessing.

        INPUT:
        idx: index of the image to be loaded.

        OUTPUT:
        sample: tuple (image, label). image is the result of read_image passed
            through img_transform (when one was given); label is the encoded
            integer class index, or -1 when the dataset has no labels.
        """
        image = read_image(self.images[idx])
        if self.is_train:
            # Encode only the requested label; the encoder is looked up at access
            # time so a fitted encoder assigned after construction is honoured.
            label = self.le.transform([self.raw_labels[idx]]).astype(int)[0]
        else:
            label = -1

        if self.img_transform is not None:
            image = self.img_transform(image)
        return image, label



def load_train_data(train_data, batch_size=200, num_workers=20):
    """Build the DataLoader for the labelled training csv.

    INPUT:
    train_data: Path to the train csv file.
    batch_size: Number of samples per batch (default 200, the previous fixed value).
    num_workers: Number of DataLoader worker processes used for image loading
        (default 20, the previous fixed value). Use 0 to load in the main process.

    OUTPUT:
    DataLoader over an ImageDataset that fits its label encoder on this csv.
    Samples are served in csv order (shuffle=False).
    """
    # Round-trip through PIL so ToTensor yields a float tensor from the loaded image.
    img_transforms = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

    train_dataset = ImageDataset(
        data_csv=train_data,
        train=True,
        to_encode=True,
        img_transform=img_transforms,
    )
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )


def load_test_data_with_labels(test_data, batch_size=200, num_workers=20, label_encoder=None):
    """Build the DataLoader for a labelled test csv.

    INPUT:
    test_data: Path to the test csv file (image paths and labels).
    batch_size: Number of samples per batch (default 200, the previous fixed value).
    num_workers: Number of DataLoader worker processes used for image loading
        (default 20, the previous fixed value). Use 0 to load in the main process.
    label_encoder: Optional already-fitted encoder (typically the training
        dataset's `le`) used to encode the test labels. The dataset is created
        with to_encode=False, so its own encoder is never fitted; without a fitted
        encoder, reading labelled items fails inside the encoder's transform.

    OUTPUT:
    DataLoader over the labelled test ImageDataset, in csv order (shuffle=False).
    """
    # Round-trip through PIL so ToTensor yields a float tensor from the loaded image.
    img_transforms = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

    test_dataset = ImageDataset(
        data_csv=test_data,
        train=True,
        to_encode=False,
        img_transform=img_transforms,
    )
    if label_encoder is not None:
        # Reuse the caller's fitted encoder so test labels map to the same
        # class indices as the data it was fitted on.
        test_dataset.le = label_encoder
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )



def load_test_data_without_labels(test_data, batch_size=200, num_workers=20):
    """Build the DataLoader for an unlabelled test csv.

    INPUT:
    test_data: Path to the test csv file (image paths only).
    batch_size: Number of samples per batch (default 200, the previous fixed value).
    num_workers: Number of DataLoader worker processes used for image loading
        (default 20, the previous fixed value). Use 0 to load in the main process.

    OUTPUT:
    DataLoader over an ImageDataset created with train=False, so every sample's
    label is the placeholder -1. Samples are served in csv order (shuffle=False).
    """
    # Round-trip through PIL so ToTensor yields a float tensor from the loaded image.
    img_transforms = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

    test_dataset = ImageDataset(
        data_csv=test_data,
        train=False,
        to_encode=False,
        img_transform=img_transforms,
    )
    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

