# Training the CNN

I will train a CNN on the images in the `Dataset` folder. The resulting CNN will be applied to each frame of a video and will return an embedding for each frame. These embeddings will then be fed into an LSTM that classifies the video as either real or fake.



## Image preprocessing and data loaders

The images in the `Dataset` folder are 256x256 pixels, but the videos we will classify with the LSTM are preprocessed to 64x64. For this reason, the CNN will also be trained on 64x64 images. The following code resizes the dataset and prepares the DataLoader object used to train the CNN.


In [16]:
import torch
from torchvision.io import read_image
from torchvision.transforms.functional import resize
import os
from tqdm import tqdm

def resize_image(img_path, size=(64, 64)):
    """Read an image file and return it resized to ``size``.

    Args:
        img_path: Path of the image file; passed straight to
            ``torchvision.io.read_image``.
        size: Target ``(height, width)`` in pixels. Defaults to ``(64, 64)``,
            the resolution the rest of this notebook trains on.

    Returns:
        The tensor returned by ``torchvision.transforms.functional.resize``.
    """
    img = read_image(img_path)
    # resize() is given a list, so accept any sequence (tuple, list) here.
    return resize(img, size=list(size))

# testing the function
# import matplotlib.pyplot as plt
# before = read_image("Dataset/Validation/fake/fake_0.jpg")
# after = resize_image("Dataset/Validation/fake/fake_0.jpg")
# fig, axes = plt.subplots(nrows= 1, ncols= 2)
# axes = axes.flatten()
# axes[0].imshow(before.permute(1,2,0))
# axes[1].imshow(after.permute(1,2,0))

def resize_dataset(output_dir, input_root="Dataset"):
    """Resize every ``.jpg`` of the dataset and save it as a ``.pth`` tensor.

    Walks ``<input_root>/<subset>/<class>`` for the subsets ``Train``,
    ``Validation`` and ``Test`` and the classes ``real`` and ``fake``, and
    mirrors that layout under ``output_dir``.

    Args:
        output_dir: Root directory that receives the resized tensors; missing
            sub-directories are created.
        input_root: Root directory of the original images. Defaults to
            ``"Dataset"``.
    """
    for subset in ['Train', 'Validation', 'Test']:
        for class_name in ['real', 'fake']:
            input_dir = os.path.join(input_root, subset, class_name)
            new_output_dir = os.path.join(output_dir, subset, class_name)
            os.makedirs(new_output_dir, exist_ok=True)

            img_files = [f for f in os.listdir(input_dir) if f.endswith('.jpg')]

            for img_file in tqdm(img_files, desc=f"Processing {subset} {class_name}"):
                img_path = os.path.join(input_dir, img_file)
                processed_image = resize_image(img_path)

                # Swap only the extension: str.replace would also rewrite a
                # ".jpg" occurring in the middle of the file name.
                stem, _ = os.path.splitext(img_file)
                output_path = os.path.join(new_output_dir, stem + '.pth')
                torch.save(processed_image, output_path)

# resize_dataset("..\\resized_images")

Processing Train real: 100%|██████████| 70001/70001 [03:19<00:00, 351.50it/s]
Processing Train fake: 100%|██████████| 70001/70001 [03:22<00:00, 344.97it/s]
Processing Validation real: 100%|██████████| 19787/19787 [00:54<00:00, 365.66it/s]
Processing Validation fake: 100%|██████████| 19641/19641 [00:53<00:00, 366.82it/s]
Processing Test real: 100%|██████████| 5413/5413 [00:14<00:00, 366.39it/s]
Processing Test fake: 100%|██████████| 5492/5492 [00:15<00:00, 355.21it/s]


In [36]:
import os
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class ImageDataset(Dataset):
    """Dataset of pre-resized image tensors stored as ``.pth`` files.

    Files are read from ``<processed_dir>/<subset>/fake`` (label ``1``) and
    ``<processed_dir>/<subset>/real`` (label ``0``). The class directory names
    are matched case-insensitively, so both ``fake``/``real`` and
    ``Fake``/``Real`` layouts load on any filesystem.

    Args:
        processed_dir: Root directory holding the resized tensors.
        subset: Name of the split directory inside ``processed_dir``.

    Raises:
        FileNotFoundError: If the subset directory or one of the two class
            directories does not exist.
    """

    def __init__(self, processed_dir = "../resized_images", subset = "Train"):
        self.directory = os.path.join(processed_dir, subset)
        # List of (tensor file path, integer label) pairs.
        self.images = self._get_img_paths()
        self._shuffle_data()

    def _find_class_dir(self, class_name):
        """Return the path of the class directory, ignoring letter case."""
        exact = os.path.join(self.directory, class_name)
        if os.path.isdir(exact):
            return exact
        # Fall back to a case-insensitive scan for layouts such as "Fake".
        wanted = class_name.lower()
        for entry in os.listdir(self.directory):
            candidate = os.path.join(self.directory, entry)
            if entry.lower() == wanted and os.path.isdir(candidate):
                return candidate
        raise FileNotFoundError(
            f"No '{class_name}' directory found in {self.directory}"
        )

    def _get_img_paths(self):
        """Collect ``(path, label)`` pairs for every ``.pth`` file of both classes."""
        imgs = []
        for class_name, label in (("fake", 1), ("real", 0)):
            class_dir = self._find_class_dir(class_name)
            # Sorted so the pre-shuffle order does not depend on os.listdir;
            # non-.pth entries are skipped because torch.load cannot read them.
            imgs.extend(
                (os.path.join(class_dir, file_name), label)
                for file_name in sorted(os.listdir(class_dir))
                if file_name.endswith(".pth")
            )
        return imgs

    def _shuffle_data(self):
        """Shuffle ``self.images`` in place using NumPy's global RNG."""
        np.random.shuffle(self.images)

    def __getitem__(self, index):
        """Return ``(loaded tensor, label)`` for the sample at ``index``."""
        features, target = self.images[index]
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        return torch.load(features), target

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)
    

# Validation split, read from ImageDataset's default processed directory.
val_img_dataset = ImageDataset(subset="Validation")

# Serve the validation samples in shuffled mini-batches of 64.
val_img_loader = DataLoader(
    dataset=val_img_dataset,
    batch_size=64,
    shuffle=True,
)

In [37]:
# Sanity check: pull one batch from the loader and report the tensor shapes.
_first_batch = next(iter(val_img_loader))
x, y = _first_batch
print(f"{x.shape} is the shape of x\n{y.shape} is the shape of y")

torch.Size([64, 3, 64, 64]) is the shape of x
torch.Size([64]) is the shape of y
