In [None]:
# Install the Kaggle CLI into the running kernel's environment.
# %pip (rather than !pip) guarantees the package lands in the interpreter this
# notebook is using; -q keeps the install log out of the saved output.
%pip install -q kaggle



In [None]:
from google.colab import files

# Open Colab's file-upload prompt; the file expected here is the Kaggle API
# token (kaggle.json) that the following cell moves into ~/.kaggle.
# NOTE(review): the recorded output shows the upload stored as
# "kaggle (1).json", i.e. a kaggle.json was already present in the working
# directory. The later `mv kaggle.json ...` would then move that older file,
# not this upload -- confirm both hold the same token.
uploaded = files.upload()

Saving kaggle.json to kaggle (1).json


In [None]:
# Move the API token into ~/.kaggle (presumably where the kaggle CLI looks
# for credentials -- confirm against the Kaggle API docs).
# -p: do not fail if the directory already exists, so the cell can be re-run.
!mkdir -p ~/.kaggle
# NOTE(review): this fails on a re-run once kaggle.json has already been moved.
!mv kaggle.json ~/.kaggle/
# 600 = read/write for the owner only; keeps the token private.
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the ravidussilva/real-ai-art dataset archive with the Kaggle CLI.
# Per the recorded output this writes real-ai-art.zip (~10 GB) to /content.
!kaggle datasets download -d ravidussilva/real-ai-art

Downloading real-ai-art.zip to /content
100% 9.94G/9.95G [01:52<00:00, 145MB/s]
100% 9.95G/9.95G [01:52<00:00, 94.9MB/s]


In [None]:
# Extract the dataset archive into the current directory.
# -q: suppress the per-file listing (thousands of lines that only bloat the
#     notebook output).
# -n: never overwrite existing files, so re-running the cell skips what is
#     already extracted instead of stopping at an interactive "replace?" prompt.
!unzip -q -n real-ai-art.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_114.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_116.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_142.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_154.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_183.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_186.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_201.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_219.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_230.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_274.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_279.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_286.jpg  
  inflating: Real_AI_SD_LD_Dataset/train/ukiyo_e/adachi-ginko_3.jpg  
 

In [None]:
import os

import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import ToTensor

class AIArtBenchDataset(Dataset):
    """Image dataset over the ``train`` or ``test`` split of a directory tree.

    The split directory (``<root>/train`` or ``<root>/test``) is expected to
    contain one sub-directory per category; every file inside those
    sub-directories becomes one sample.  The label is derived from the
    sample's path relative to the split directory:

    * ``0`` -- path contains both ``'AI'`` and ``'SD'`` (standard diffusion art)
    * ``1`` -- path contains ``'AI'`` but not ``'SD'`` (latent diffusion art)
    * ``2`` -- path does not contain ``'AI'`` (real art)

    Args:
        root: Directory that holds the ``train`` and ``test`` sub-directories.
        for_training: Use ``<root>/train`` when truthy, ``<root>/test`` otherwise.
        transforms: Optional callable applied to the decoded image tensor.
        target_transforms: Optional callable applied to the integer label.

    Attributes:
        root: Path of the selected split directory.
        fnames_list: Relative ``<sub-directory>/<file>`` path of every sample,
            in sorted order.
    """

    def __init__(self, root, for_training=True, transforms=None, target_transforms=None):
        self.for_training = for_training
        self.transforms = transforms
        self.target_transforms = target_transforms

        self.root = os.path.join(root, 'train' if self.for_training else 'test')

        # os.listdir returns entries in arbitrary, platform-dependent order.
        # Sorting fixes the index -> file mapping so that a seeded split of
        # this dataset selects the same samples on every run and machine.
        self.fnames_list = []
        for directory in sorted(os.listdir(self.root)):
            class_dir = os.path.join(self.root, directory)
            if not os.path.isdir(class_dir):
                # Stray files next to the category folders are not samples.
                continue
            for image in sorted(os.listdir(class_dir)):
                self.fnames_list.append(os.path.join(directory, image))

    def __len__(self):
        """Return the number of samples found in the split directory."""
        return len(self.fnames_list)

    @staticmethod
    def _label_for(img_name):
        """Map a sample's relative path to its integer class label."""
        if 'AI' in img_name:
            if 'SD' in img_name:
                return 0  # 0 for standard diffusion art
            return 1  # 1 for latent diffusion art
        return 2  # 2 for real art

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            An ``(image, label)`` tuple.  ``image`` is the decoded image as a
            3-channel tensor (after ``transforms``, if given); ``label`` is
            the class index (after ``target_transforms``, if given).
        """
        img_name = self.fnames_list[idx]
        img_path = os.path.join(self.root, img_name)
        # Force RGB decoding: grayscale or alpha-channel files would otherwise
        # yield 1- or 4-channel tensors, which break 3-channel normalisation
        # and cannot be stacked into a batch with RGB images.
        image = read_image(img_path, mode=ImageReadMode.RGB)

        label = self._label_for(img_name)

        if self.transforms:
            image = self.transforms(image)
        if self.target_transforms:
            label = self.target_transforms(label)

        return image, label

In [None]:
import numpy as np
import random
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.transforms import Resize, Normalize, Compose, ToTensor, Lambda
import torchvision.transforms.functional as F

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose([
    # The dataset yields uint8 tensors in [0, 255].  Normalize below uses
    # mean/std expressed on a [0, 1] scale (the usual ImageNet statistics),
    # so the pixels must be rescaled as well as cast to float.
    Lambda(lambda x: x.float() / 255.0),
    transforms.Resize((299, 299)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Full training split; it is divided into train/validation subsets below.
full_dataset = AIArtBenchDataset(root='/content/Real_AI_SD_LD_Dataset', for_training=True, transforms=transform)

# Set seed for reproducibility
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Split the dataset into training (80%) and validation (20%) sets.
# A dedicated seeded generator makes the split independent of how much of
# the global RNG stream has been consumed before this point.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

# Shuffle only the training data; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)