In [29]:
import torch
import torchvision
import opendatasets as od
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import os
import random
import shutil
from tqdm.auto import tqdm

In [2]:
od.download("https://www.kaggle.com/datasets/antobenedetti/animals")

Dataset URL: https://www.kaggle.com/datasets/antobenedetti/animals
Downloading animals.zip to .\animals


100%|██████████| 882M/882M [00:59<00:00, 15.6MB/s] 





In [5]:
class AnimalDataset(Dataset):
    def __init__(self, splits,transform=None):
        self.dataset = ConcatDataset([torchvision.datasets.ImageFolder(split, transform=transform) for split in splits])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

In [7]:
len(os.listdir("animals/animals/train/cat"))

2737

In [8]:
len(os.listdir("animals/animals/train/dog"))

2627

In [9]:
len(os.listdir("animals/animals/train/elephant"))

2730

In [10]:
len(os.listdir("animals/animals/train/lion"))

2675

In [11]:
len(os.listdir("animals/animals/train/horse"))

2705

In [14]:
cats = os.listdir("animals/animals/train/cat")
dogs = os.listdir("animals/animals/train/dog")
elephants = os.listdir("animals/animals/train/elephant")
lions = os.listdir("animals/animals/train/lion")
horses = os.listdir("animals/animals/train/horse")

In [15]:
random.shuffle(cats)
random.shuffle(dogs)
random.shuffle(elephants)
random.shuffle(lions)
random.shuffle(horses)
train_cats, test_cats = cats[:int(len(cats)*0.8)], cats[int(len(cats)*0.8):]
train_dogs, test_dogs = dogs[:int(len(dogs)*0.8)], dogs[int(len(dogs)*0.8):]
train_elephants, test_elephants = elephants[:int(len(elephants)*0.8)], elephants[int(len(elephants)*0.8):]
train_lions, test_lions = lions[:int(len(lions)*0.8)], lions[int(len(lions)*0.8):]
train_horses, test_horses = horses[:int(len(horses)*0.8)], horses[int(len(horses)*0.8):]

In [16]:
def generate_folds(list_of_imgs, n_folds):
    fold_size = len(list_of_imgs)//n_folds
    folds = []
    for i in range(n_folds):
        folds.append(list_of_imgs[i*fold_size:(i+1)*fold_size])
    return folds

In [21]:
train_cats_folds = generate_folds(train_cats, 10)
train_dogs_folds = generate_folds(train_dogs, 10)
train_elephants_folds = generate_folds(train_elephants, 10)
train_lions_folds = generate_folds(train_lions, 10)
train_horses_folds = generate_folds(train_horses, 10)

In [22]:
[len(fold) for fold in train_cats_folds]

[218, 218, 218, 218, 218, 218, 218, 218, 218, 218]

In [30]:
def move_images(root_dir, target_dir, list_of_imgs):
    for img in tqdm(list_of_imgs,"progress"):
        shutil.move(os.path.join(root_dir, img), os.path.join(target_dir, img))

In [34]:
def move_folds(root_dir, target_dir, folds):
    for idx,fold in enumerate(folds):
        print(f"Moving fold {idx}")
        os.mkdir(os.path.join(target_dir,f"fold_{idx}"))
        move_images(root_dir, os.path.join(target_dir,f"fold_{idx}"), fold)

In [35]:
move_folds("animals/animals/train/cat", "animals_folds/train/cat", train_cats_folds)

Moving fold 0


progress:   0%|          | 0/218 [00:00<?, ?it/s]

progress: 100%|██████████| 218/218 [00:00<00:00, 796.56it/s]


Moving fold 1


progress: 100%|██████████| 218/218 [00:00<00:00, 595.62it/s]


Moving fold 2


progress: 100%|██████████| 218/218 [00:00<00:00, 644.27it/s]


Moving fold 3


progress: 100%|██████████| 218/218 [00:00<00:00, 663.06it/s]


Moving fold 4


progress: 100%|██████████| 218/218 [00:00<00:00, 664.17it/s]


Moving fold 5


progress: 100%|██████████| 218/218 [00:00<00:00, 610.36it/s]


Moving fold 6


progress: 100%|██████████| 218/218 [00:00<00:00, 858.44it/s]


Moving fold 7


progress: 100%|██████████| 218/218 [00:00<00:00, 832.90it/s]


Moving fold 8


progress: 100%|██████████| 218/218 [00:00<00:00, 728.91it/s] 


Moving fold 9


progress: 100%|██████████| 218/218 [00:00<00:00, 575.49it/s]


In [36]:
move_folds("animals/animals/train/dog", "animals_folds/train/dog", train_dogs_folds)

Moving fold 0


progress: 100%|██████████| 210/210 [00:00<00:00, 489.63it/s]


Moving fold 1


progress: 100%|██████████| 210/210 [00:00<00:00, 691.46it/s]


Moving fold 2


progress: 100%|██████████| 210/210 [00:00<00:00, 682.50it/s]


Moving fold 3


progress: 100%|██████████| 210/210 [00:00<00:00, 663.79it/s]


Moving fold 4


progress: 100%|██████████| 210/210 [00:00<00:00, 660.66it/s]


Moving fold 5


progress: 100%|██████████| 210/210 [00:00<00:00, 613.14it/s]


Moving fold 6


progress: 100%|██████████| 210/210 [00:00<00:00, 585.18it/s]


Moving fold 7


progress: 100%|██████████| 210/210 [00:00<00:00, 777.05it/s]


Moving fold 8


progress: 100%|██████████| 210/210 [00:00<00:00, 797.56it/s]


Moving fold 9


progress: 100%|██████████| 210/210 [00:00<00:00, 767.15it/s]


In [37]:
move_folds("animals/animals/train/lion", "animals_folds/train/lion", train_lions_folds)

Moving fold 0


progress: 100%|██████████| 214/214 [00:00<00:00, 828.00it/s]


Moving fold 1


progress: 100%|██████████| 214/214 [00:00<00:00, 565.64it/s]


Moving fold 2


progress: 100%|██████████| 214/214 [00:00<00:00, 710.09it/s]


Moving fold 3


progress: 100%|██████████| 214/214 [00:00<00:00, 588.09it/s]


Moving fold 4


progress: 100%|██████████| 214/214 [00:00<00:00, 547.10it/s]


Moving fold 5


progress: 100%|██████████| 214/214 [00:00<00:00, 684.54it/s]


Moving fold 6


progress: 100%|██████████| 214/214 [00:00<00:00, 794.29it/s]


Moving fold 7


progress: 100%|██████████| 214/214 [00:00<00:00, 955.18it/s]


Moving fold 8


progress: 100%|██████████| 214/214 [00:00<00:00, 777.22it/s]


Moving fold 9


progress: 100%|██████████| 214/214 [00:00<00:00, 768.60it/s]


In [38]:
move_folds("animals/animals/train/elephant", "animals_folds/train/elephant", train_elephants_folds)

Moving fold 0


progress: 100%|██████████| 218/218 [00:00<00:00, 785.66it/s]


Moving fold 1


progress: 100%|██████████| 218/218 [00:00<00:00, 1221.77it/s]


Moving fold 2


progress: 100%|██████████| 218/218 [00:00<00:00, 918.19it/s] 


Moving fold 3


progress: 100%|██████████| 218/218 [00:00<00:00, 1006.29it/s]


Moving fold 4


progress: 100%|██████████| 218/218 [00:00<00:00, 945.78it/s]


Moving fold 5


progress: 100%|██████████| 218/218 [00:00<00:00, 846.13it/s]


Moving fold 6


progress: 100%|██████████| 218/218 [00:00<00:00, 902.85it/s] 


Moving fold 7


progress: 100%|██████████| 218/218 [00:00<00:00, 791.82it/s]


Moving fold 8


progress: 100%|██████████| 218/218 [00:00<00:00, 860.78it/s]


Moving fold 9


progress: 100%|██████████| 218/218 [00:00<00:00, 697.47it/s]


In [39]:
move_folds("animals/animals/train/horse", "animals_folds/train/horse", train_horses_folds)

Moving fold 0


progress: 100%|██████████| 216/216 [00:00<00:00, 736.73it/s]


Moving fold 1


progress: 100%|██████████| 216/216 [00:00<00:00, 786.62it/s]


Moving fold 2


progress: 100%|██████████| 216/216 [00:00<00:00, 719.99it/s]


Moving fold 3


progress: 100%|██████████| 216/216 [00:00<00:00, 554.80it/s]


Moving fold 4


progress: 100%|██████████| 216/216 [00:00<00:00, 912.00it/s]


Moving fold 5


progress: 100%|██████████| 216/216 [00:00<00:00, 754.30it/s]


Moving fold 6


progress: 100%|██████████| 216/216 [00:00<00:00, 666.63it/s]


Moving fold 7


progress: 100%|██████████| 216/216 [00:00<00:00, 880.68it/s] 


Moving fold 8


progress: 100%|██████████| 216/216 [00:00<00:00, 594.52it/s]


Moving fold 9


progress: 100%|██████████| 216/216 [00:00<00:00, 695.61it/s]


In [42]:
total_files = len([
    f for fold in os.listdir("animals_folds/train/cat")  # Iterate through subdirectories
    for f in os.listdir(f"animals_folds/train/cat/{fold}")  # Iterate through files in each subdirectory
])

In [44]:
total_files

2180

In [45]:
move_images("animals/animals/train/cat", "animals_folds/test/cat", test_cats)

progress: 100%|██████████| 548/548 [00:00<00:00, 611.40it/s]


In [47]:
move_images("animals/animals/train/dog", "animals_folds/test/dog", test_dogs)

progress: 100%|██████████| 526/526 [00:00<00:00, 616.44it/s]


In [48]:
move_images("animals/animals/train/elephant", "animals_folds/test/elephant", test_elephants)

progress: 100%|██████████| 546/546 [00:00<00:00, 647.86it/s]


In [49]:
move_images("animals/animals/train/lion", "animals_folds/test/lion", test_lions)

progress: 100%|██████████| 535/535 [00:00<00:00, 699.65it/s]


In [50]:
move_images("animals/animals/train/horse", "animals_folds/test/horse", test_horses)

progress: 100%|██████████| 541/541 [00:00<00:00, 912.37it/s]
