In [3]:
import torch
import torchvision.transforms as transforms 
import pandas as pd
from PIL import Image

In [4]:
class StartingDataset(torch.utils.data.Dataset):
    def __init__(self, df_path="cass_data/train.csv", img_path="cass_data/train_images", train=True):
        df = pd.read_csv(df_path).sample(frac = 1, random_state=42).reset_index(drop=True)

        train_percentage = 0.8
        rows = df.shape[0]
        train_rows = int(rows * train_percentage)

        if train:
            self.df = df.iloc[:train_rows]
        else:
            self.df = df.iloc[train_rows:]
        
        self.df = self.df.reset_index(drop=True)

        self.img_path = img_path

        self.transforms = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        jpg_str = str(self.df.loc[index]['image_id'])
        label = self.df.loc[index]['label']

        with Image.open(f"{self.img_path}/{jpg_str}") as im:
            im = self.transforms(im)

        print("label:", label)
        return im, label

    def __len__(self):
        return len(self.df)

In [5]:
df_path = "../cass_data/train.csv"
img_path = "../cass_data/train_images"

In [6]:
train = StartingDataset(df_path=df_path, img_path=img_path, train=False)

In [7]:
len(train)

4280

In [8]:
train.df.head()

Unnamed: 0,image_id,label
0,2824543301.jpg,3
1,184909120.jpg,3
2,2602456265.jpg,3
3,1331491784.jpg,3
4,414363375.jpg,3


In [9]:
train[30]

label: 3


(tensor([[[0.1961, 0.1765, 0.1216,  ..., 0.2431, 0.3412, 0.4471],
          [0.1490, 0.1137, 0.1020,  ..., 0.1961, 0.3294, 0.5569],
          [0.1490, 0.1451, 0.1216,  ..., 0.1882, 0.3333, 0.4902],
          ...,
          [0.0196, 0.0353, 0.0588,  ..., 0.2745, 0.1725, 0.0745],
          [0.0196, 0.0353, 0.0667,  ..., 0.3529, 0.3137, 0.2039],
          [0.0196, 0.0353, 0.0588,  ..., 0.4549, 0.2980, 0.4353]],
 
         [[0.1529, 0.1333, 0.0863,  ..., 0.1725, 0.2667, 0.3686],
          [0.0902, 0.0588, 0.0471,  ..., 0.1255, 0.2588, 0.4824],
          [0.0784, 0.0784, 0.0627,  ..., 0.1176, 0.2627, 0.4157],
          ...,
          [0.0157, 0.0314, 0.0549,  ..., 0.2275, 0.1255, 0.0431],
          [0.0157, 0.0314, 0.0627,  ..., 0.3020, 0.2627, 0.1686],
          [0.0157, 0.0314, 0.0549,  ..., 0.4039, 0.2471, 0.4000]],
 
         [[0.0627, 0.0784, 0.0431,  ..., 0.1176, 0.2039, 0.2784],
          [0.0549, 0.0471, 0.0471,  ..., 0.0706, 0.1843, 0.3765],
          [0.0549, 0.0667, 0.0549,  ...,