###  Import the necessary libraries

In [1]:
import csv
import os
import random
import re
import zipfile
from random import sample
from struct import unpack

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from PIL import ImageFile
from sklearn.utils import shuffle
from torchvision import transforms
from tqdm import tqdm

# Data Preparation

## 1. Data Loading and cleaning

**This code loads the dataset metadata from a CSV file and removes every row that refers to a corrupted image file. The IDs of the corrupted files are taken from the member names of a zip archive of replacement files: a regular expression strips all non-digit characters from each name. Rows whose filename yields one of those IDs are dropped from the DataFrame, and the cleaned dataset is written to a new CSV file, `all_data_info_1.csv`.**

In [2]:
# Load the image metadata from the csv file
df = pd.read_csv('all_data_info.csv')

# Collect the numeric IDs of the files shipped in the replacement archive.
# The context manager closes the archive handle as soon as the member names
# have been read, instead of leaving it open for the rest of the session.
with zipfile.ZipFile('replacements_for_corrupted_files.zip', 'r') as archive:
    # Strip every non-digit character from each member name; names without
    # any digit produce "" and are ignored.
    corrupted_ids = {
        digits
        for digits in (re.sub("[^0-9]", "", item) for item in archive.namelist())
        if digits != ""
    }

# Apply the same digit extraction to every 'new_filename' in one vectorised
# pass instead of iterating over the rows one by one
row_ids = df['new_filename'].str.replace("[^0-9]", "", regex=True)

# Indices of the rows whose ID belongs to a corrupted file
drop_idx = df.index[row_ids.isin(corrupted_ids)].tolist()

# Drop the rows with indices in the drop_idx list from the DataFrame
df = df.drop(drop_idx)

# Write the cleaned DataFrame to a separate csv file.
# to_csv overwrites an existing file, so no explicit delete is needed first.
df.to_csv('all_data_info_1.csv', index=False)


In [3]:
# Preview the first rows of the metadata DataFrame
df.head()

Unnamed: 0,artist,date,genre,pixelsx,pixelsy,size_bytes,source,style,title,artist_group,in_train,new_filename
0,Barnett Newman,1955.0,abstract,15530.0,6911.0,9201912.0,wikiart,Color Field Painting,Uriel,train_only,True,102257.jpg
1,Barnett Newman,1950.0,abstract,14559.0,6866.0,8867532.0,wikiart,Color Field Painting,Vir Heroicus Sublimis,train_only,True,75232.jpg
2,kiri nichol,2013.0,,9003.0,9004.0,1756681.0,,Neoplasticism,,test_only,False,32145.jpg
3,kiri nichol,2013.0,,9003.0,9004.0,1942046.0,,Neoplasticism,,test_only,False,20304.jpg
4,kiri nichol,2013.0,,9003.0,9004.0,1526212.0,,Neoplasticism,,test_only,False,836.jpg


**The code below uses Python's zipfile module to extract the contents of two zip files, train_2.zip and test.zip.**

**The first block of code uses the with statement to create a ZipFile object from the file path of train_2.zip, opened in read mode ('r'). The extractall() method is then called on this object to extract all the contents of the zip file into the directory named 'train'.**

**The second block of code follows a similar pattern, but this time it opens the test.zip file and extracts its contents to a directory named 'test'.**

**In summary, these two blocks of code extract the contents of two zip files and store them in two separate directories.**

In [5]:
# Directory prefix that is prepended to the zip archive file names
path = './'

In [6]:
# Unpack every member of the training archive into the 'train' directory
with zipfile.ZipFile(path+'train_2.zip', mode='r') as train_archive:
    train_archive.extractall(path='train')

In [19]:
# Unpack every member of the test archive into the 'test' directory
with zipfile.ZipFile(path+'test.zip', mode='r') as test_archive:
    test_archive.extractall(path='test')

In [8]:
# Iterate over all the images in the "train/train_2" folder and delete every
# image that has more than three bands (anything beyond RGB).
for filename in os.listdir("train/train_2"):
    image_path = os.path.join("train/train_2", filename)

    # The context manager closes the file handle of every inspected image
    # (not only the deleted ones), and the file is closed before removal.
    with Image.open(image_path) as image:
        has_extra_channels = len(image.getbands()) > 3

    if has_extra_channels:
        print(f"Skipping {filename} as it has extra channels other than RGB.")
        os.remove(image_path)

Skipping 21733.jpg as it has extra channels other than RGB.
Skipping 22580.jpg as it has extra channels other than RGB.
Skipping 24241.jpg as it has extra channels other than RGB.
Skipping 27735.jpg as it has extra channels other than RGB.
Skipping 2881.jpg as it has extra channels other than RGB.
Skipping 29815.jpg as it has extra channels other than RGB.
Skipping 29854.jpg as it has extra channels other than RGB.


In [23]:
# Iterate over all the images in the "test/test" folder. Files that cannot be
# opened as images and images with more than three bands are deleted.
for filename in tqdm(os.listdir("test/test")):
    image_path = os.path.join("test/test", filename)

    try:
        image = Image.open(image_path)
    except Exception:
        # `Exception` instead of a bare `except` so that interrupting the
        # kernel stops the loop rather than deleting the current file.
        print(f"Removing {filename} as it could not be opened as an image.")
        os.remove(image_path)
        continue

    # Close the handle of every opened image before (possibly) removing it
    with image:
        has_extra_channels = len(image.getbands()) > 3

    if has_extra_channels:
        print(f"Skipping {filename} as it has extra channels other than RGB.")
        os.remove(image_path)

 20%|██████████████▋                                                            | 4668/23815 [00:01<00:10, 1805.20it/s]

Skipping 25416.jpg as it has extra channels other than RGB.


 26%|███████████████████▋                                                        | 6166/23815 [00:06<00:40, 430.58it/s]

Skipping 30623.jpg as it has extra channels other than RGB.


 28%|████████████████████▉                                                       | 6569/23815 [00:08<00:49, 349.10it/s]

Skipping 32071.jpg as it has extra channels other than RGB.


 29%|█████████████████████▉                                                      | 6862/23815 [00:09<00:58, 289.81it/s]

Skipping 33273.jpg as it has extra channels other than RGB.


 32%|████████████████████████▎                                                   | 7599/23815 [00:11<00:52, 306.43it/s]

Skipping 35962.jpg as it has extra channels other than RGB.


 35%|██████████████████████████▋                                                 | 8347/23815 [00:14<00:51, 300.92it/s]

Skipping 39101.jpg as it has extra channels other than RGB.


 40%|██████████████████████████████▍                                             | 9552/23815 [00:18<00:47, 303.06it/s]

Skipping 44017.jpg as it has extra channels other than RGB.


 46%|██████████████████████████████████▊                                        | 11053/23815 [00:23<00:42, 297.33it/s]

Skipping 49778.jpg as it has extra channels other than RGB.


 47%|███████████████████████████████████                                        | 11147/23815 [00:24<00:41, 303.17it/s]

Skipping 5018.jpg as it has extra channels other than RGB.


 49%|█████████████████████████████████████                                      | 11784/23815 [00:26<00:41, 286.50it/s]

Skipping 52742.jpg as it has extra channels other than RGB.


 62%|██████████████████████████████████████████████▏                            | 14672/23815 [00:37<00:30, 303.12it/s]

Skipping 64004.jpg as it has extra channels other than RGB.


 67%|██████████████████████████████████████████████████▌                        | 16040/23815 [00:41<00:26, 290.69it/s]

Skipping 69423.jpg as it has extra channels other than RGB.


 68%|██████████████████████████████████████████████████▊                        | 16135/23815 [00:42<00:25, 306.59it/s]

Skipping 69828.jpg as it has extra channels other than RGB.
Skipping 70061.jpg as it has extra channels other than RGB.


 71%|█████████████████████████████████████████████████████▎                     | 16917/23815 [00:44<00:23, 298.57it/s]

Skipping 72972.jpg as it has extra channels other than RGB.


 74%|███████████████████████████████████████████████████████▏                   | 17523/23815 [00:46<00:21, 297.86it/s]

Skipping 75551.jpg as it has extra channels other than RGB.


 75%|████████████████████████████████████████████████████████▌                  | 17962/23815 [00:48<00:19, 307.22it/s]

Skipping 77299.jpg as it has extra channels other than RGB.


 88%|██████████████████████████████████████████████████████████████████▏        | 21026/23815 [00:58<00:09, 298.54it/s]

Skipping 89073.jpg as it has extra channels other than RGB.


 91%|████████████████████████████████████████████████████████████████████       | 21615/23815 [01:01<00:08, 269.53it/s]

Skipping 91447.jpg as it has extra channels other than RGB.


 92%|█████████████████████████████████████████████████████████████████████▏     | 21963/23815 [01:02<00:05, 311.54it/s]

Skipping 92799.jpg as it has extra channels other than RGB.


 93%|█████████████████████████████████████████████████████████████████████▌     | 22086/23815 [01:02<00:06, 280.87it/s]

Skipping 93262.jpg as it has extra channels other than RGB.


 93%|█████████████████████████████████████████████████████████████████████▉     | 22207/23815 [01:03<00:05, 290.47it/s]

Skipping 93738.jpg as it has extra channels other than RGB.


 94%|██████████████████████████████████████████████████████████████████████▋    | 22442/23815 [01:03<00:05, 258.39it/s]

Skipping 94705.jpg as it has extra channels other than RGB.


100%|██████████████████████████████████████████████████████████████████████████▊| 23773/23815 [01:08<00:00, 298.16it/s]

Skipping 99592.jpg as it has extra channels other than RGB.


100%|███████████████████████████████████████████████████████████████████████████| 23815/23815 [01:08<00:00, 345.80it/s]


In [10]:
from struct import unpack
from tqdm import tqdm
import os


# Human-readable names for two-byte JPEG marker codes
# (handy when debugging a walk over the marker segments of a file)
marker_mapping = {
    0xffd8: "Start of Image",
    0xffe0: "Application Default Header",
    0xffdb: "Quantization Table",
    0xffc0: "Start of Frame",
    0xffc4: "Define Huffman Table",
    0xffda: "Start of Scan",
    0xffd9: "End of Image"
}


class JPEG:
    """Structural checker for JPEG files.

    The file content is read once at construction time; ``decode`` then walks
    the marker segments to verify that the stream reaches an End of Image
    marker. It does not decompress any pixel data.
    """

    def __init__(self, image_file):
        """Read the complete content of ``image_file`` into memory as bytes."""
        with open(image_file, 'rb') as f:
            self.img_data = f.read()

    def decode(self):
        """Walk the marker segments until the End of Image marker is found.

        Returns:
            None, once the End of Image marker (0xffd9) is reached.

        Raises:
            struct.error: if fewer than two bytes are available where a marker
                or a segment length is expected (e.g. the scan data does not
                end with the End of Image marker).
            ValueError: if the data runs out before an End of Image marker is
                seen (e.g. a file truncated inside its header segments).
        """
        data = self.img_data
        while True:
            marker, = unpack(">H", data[0:2])
            if marker == 0xffd8:
                # Start of Image: a bare marker without a length field
                data = data[2:]
            elif marker == 0xffd9:
                # End of Image: the stream is structurally complete
                return
            elif marker == 0xffda:
                # Start of Scan: skip the scan data and jump to the final two
                # bytes, which are then checked as a marker on the next pass
                data = data[-2:]
            else:
                # Any other segment: marker (2 bytes) followed by a big-endian
                # length that counts itself plus the payload
                lenchunk, = unpack(">H", data[2:4])
                data = data[2+lenchunk:]
            if len(data) == 0:
                # Running out of bytes here means no End of Image marker was
                # found; report it instead of returning as if the file was fine
                raise ValueError("JPEG data ended before the End of Image marker")


# Names of the training images whose JPEG marker structure cannot be walked
bads = []

for img in tqdm(os.listdir("train/train_2")):
    jpeg = JPEG(os.path.join("train/train_2", img))
    try:
        jpeg.decode()
    except Exception:
        # Any decoding error marks the file as corrupted. `Exception` (not a
        # bare `except`) keeps a kernel interrupt from flagging a good image.
        bads.append(img)


# Delete every image that failed the structural check
for name in tqdm(bads):
    os.remove(os.path.join("train/train_2", name))

100%|█████████████████████████████████████████████████████████████████████████████| 8469/8469 [00:09<00:00, 850.92it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 51/51 [00:00<00:00, 5663.78it/s]


In [24]:
# Names of the test images whose JPEG marker structure cannot be walked
bads = []

for img in tqdm(os.listdir("test/test")):
    jpeg = JPEG(os.path.join("test/test", img))
    try:
        jpeg.decode()
    except Exception:
        # Any decoding error marks the file as corrupted. `Exception` (not a
        # bare `except`) keeps a kernel interrupt from flagging a good image.
        bads.append(img)


# Delete every image that failed the structural check
for name in tqdm(bads):
    os.remove(os.path.join("test/test", name))

100%|███████████████████████████████████████████████████████████████████████████| 23789/23789 [01:31<00:00, 260.91it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 5965.94it/s]


In [12]:
# Walk the training folder and delete grayscale images as well as images
# whose last band (beyond the first three) is an alpha channel.
for filename in os.listdir("train/train_2"):
    # only .jpg / .png files are inspected -- change extensions as needed
    if not filename.endswith(('.jpg', '.png')):
        continue
    filepath = os.path.join("train/train_2", filename)

    # read mode and bands, then close the file before it may be removed
    with Image.open(filepath) as img:
        mode, bands = img.mode, img.getbands()

    is_grayscale = mode == 'L'
    has_alpha = len(bands) > 3 and bands[-1] == 'A'
    if is_grayscale or has_alpha:
        os.remove(filepath)

In [25]:
# Walk the test folder and delete grayscale images as well as images
# whose last band (beyond the first three) is an alpha channel.
for filename in os.listdir("test/test"):
    # only .jpg / .png files are inspected -- change extensions as needed
    if not filename.endswith(('.jpg', '.png')):
        continue
    filepath = os.path.join("test/test", filename)

    # read mode and bands, then close the file before it may be removed
    with Image.open(filepath) as img:
        mode, bands = img.mode, img.getbands()

    is_grayscale = mode == 'L'
    has_alpha = len(bands) > 3 and bands[-1] == 'A'
    if is_grayscale or has_alpha:
        os.remove(filepath)

**This code defines a function called del_image that takes an image file path as an argument. It reads the image file, decodes it as a JPEG, resizes it, and applies the VGG16 input preprocessing. If an exception is raised during this process, the function prints a message indicating that the image is corrupted and then removes the file using the os.remove() function. The purpose of this function is to remove corrupted images from the dataset, which can cause errors during training.**

In [13]:
def del_image(image_path):
    """Delete the file at ``image_path`` if it cannot be decoded and preprocessed.

    The file is read with TensorFlow, decoded as a 3-channel JPEG, resized to
    224x224 and passed through the VGG16 input preprocessing. If any of these
    steps raises an error, a message is printed and the file is removed.

    Args:
        image_path: path of the image file to check.

    Returns:
        None.
    """
    try:
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, (224, 224))
        tf.keras.applications.vgg16.preprocess_input(image)
    except Exception:
        # `Exception` instead of a bare `except`: a KeyboardInterrupt or
        # SystemExit must propagate and must not delete the current file.
        print(f'Removing corrupted image: {image_path}')
        os.remove(image_path)

**The following code deletes the corrupted images in the folder train/train_2 (and then in test/test) by calling the del_image function on each file.**

In [14]:
# Run the TensorFlow decode check on every file of the training folder
train_folder = 'train/train_2'
for entry_name in tqdm(os.listdir(train_folder)):
    del_image(os.path.join(train_folder, entry_name))

100%|██████████████████████████████████████████████████████████████████████████████| 8338/8338 [02:15<00:00, 61.59it/s]


In [26]:
# Run the TensorFlow decode check on every file of the test folder
# (the variable keeps its original name although it points at the test set)
train_folder = 'test/test'
for entry_name in tqdm(os.listdir(train_folder)):
    del_image(os.path.join(train_folder, entry_name))

100%|████████████████████████████████████████████████████████████████████████████| 23450/23450 [07:08<00:00, 54.70it/s]


In [15]:
# Pillow module-level switch: let truncated image files load instead of
# raising an error while their data is read (applies from this point on)
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [18]:
from PIL import Image
import os

# define input and output directories
input_dir = 'train/train_2'
output_dir = 'train/train_crop'

# define target size (width, height) for the center crop
target_size = (224, 224)

# create the output directory if it does not exist
os.makedirs(output_dir, exist_ok=True)

# loop through each image file in the input directory
for filename in os.listdir(input_dir):
    if not filename.endswith(('.jpg', '.jpeg', '.png')):
        continue
    try:
        # the context manager closes the source file once the crop is saved
        with Image.open(os.path.join(input_dir, filename)) as image:
            # calculate the crop box centered on the image
            # NOTE(review): for images smaller than target_size the box
            # extends past the image bounds -- confirm this is intended
            width, height = image.size
            left = (width - target_size[0]) / 2
            top = (height - target_size[1]) / 2
            right = (width + target_size[0]) / 2
            bottom = (height + target_size[1]) / 2

            # crop the image and save the result to the output directory
            cropped = image.crop((left, top, right, bottom))
            cropped.save(os.path.join(output_dir, filename))
    except Exception as exc:
        # keep processing the remaining files, but report the failure
        # instead of silently skipping the image
        print(f"Could not crop {filename}: {exc}")

In [27]:
from PIL import Image
import os

# define input and output directories
input_dir = 'test/test'
output_dir = 'test/test_crop'

# define target size (width, height) for the center crop
target_size = (224, 224)

# create the output directory if it does not exist
os.makedirs(output_dir, exist_ok=True)

# loop through each image file in the input directory
for filename in os.listdir(input_dir):
    if not filename.endswith(('.jpg', '.jpeg', '.png')):
        continue
    try:
        # the context manager closes the source file once the crop is saved
        with Image.open(os.path.join(input_dir, filename)) as image:
            # calculate the crop box centered on the image
            # NOTE(review): for images smaller than target_size the box
            # extends past the image bounds -- confirm this is intended
            width, height = image.size
            left = (width - target_size[0]) / 2
            top = (height - target_size[1]) / 2
            right = (width + target_size[0]) / 2
            bottom = (height + target_size[1]) / 2

            # crop the image and save the result to the output directory
            cropped = image.crop((left, top, right, bottom))
            cropped.save(os.path.join(output_dir, filename))
    except Exception as exc:
        # keep processing the remaining files, but report the failure
        # instead of silently skipping the image
        print(f"Could not crop {filename}: {exc}")

**The above code center-crops the images to reduce training complexity.**