In [1]:
# make imports

import ast
import csv
import os
import pickle
import sys
from io import BytesIO

import numpy as np
import PIL
import requests
from matplotlib import image, pyplot
from numpy import asarray
from PIL import Image, ImageEnhance

In [2]:
def read_csv(file_name):
    """Read a CSV file and return all of its rows.

    Args:
        file_name: Path to the CSV file.

    Returns:
        A list of rows, each row a list of field strings. The header row (if the
        file has one) is returned as the first row. Returns None when
        ``file_name`` does not exist.
    """
    if not os.path.exists(file_name):
        return None

    # newline='' lets the csv module handle line endings itself, so quoted
    # fields that contain embedded newlines (e.g. "\r\n") are read back intact
    # instead of being rewritten by universal-newline translation.
    with open(file_name, 'r', newline='') as file:
        return list(csv.reader(file))


given_data = read_csv('A2_Data.csv')

In [3]:
# Download every listed product image and store it as
# [image_data, prod_id, image_url] in `all_image_data`.

if given_data is None:
    raise FileNotFoundError("A2_Data.csv could not be read; cannot download images.")

# Drop the CSV header row under a new name instead of rebinding `given_data`:
# `given_data = given_data[1:]` silently discarded one more data row every time
# this cell was re-run.
data_rows = given_data[1:]

# Seconds to wait for a server before giving up on an image; without a timeout
# requests.get can block indefinitely.
REQUEST_TIMEOUT = 30


def _parse_url_list(raw_urls):
    """Return the image URLs stored in one CSV cell.

    The cell is expected to hold a Python-style list literal such as
    "['https://a.jpg', 'https://b.jpg']". ``ast.literal_eval`` parses it
    without splitting URLs that themselves contain commas (e.g. crop suffixes
    like "_CR,0,0,38,50_"). If the cell is not a valid literal, fall back to
    the strip-brackets / split(',') / strip-quotes parsing.
    """
    try:
        parsed = ast.literal_eval(raw_urls)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return [str(url).strip() for url in parsed]
    return [url.strip()[1:-1] for url in raw_urls[1:-1].split(',')]


all_image_data = []

for row in data_rows:
    prod_id = row[0]

    for cleaned_url in _parse_url_list(row[1]):
        if not cleaned_url:
            continue  # e.g. an empty "[]" cell; requests.get('') would raise

        try:
            response = requests.get(cleaned_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            # Image.open is lazy, so decode here inside the try. Converting to
            # RGB (using the palette for "P" images) yields an (H, W, 3) array,
            # matching the 3-channel normalisation applied later.
            img_array = asarray(img.convert("RGB"))
        # Catch Exception rather than using a bare `except:`: any failed
        # download or decode is skipped (best effort), but KeyboardInterrupt can
        # still stop this long loop.
        except Exception as exc:
            print("Invalid Image, Skipping...")
            print("the image info is: ", prod_id, cleaned_url)
            print("reason: ", exc)
            continue

        all_image_data.append([img_array, prod_id, cleaned_url])

Invalid Image, Skipping...
the image info is:  2235 https://images-na.ssl-images-amazon.com/images/I/71F3npeHUDL._SY88.jpg
Invalid Image, Skipping...
the image info is:  2235 https://images-na.ssl-images-amazon.com/images/I/71wHUWncMGL._SY88.jpg
Invalid Image, Skipping...
the image info is:  3317 https://images-na.ssl-images-amazon.com/images/I/71B8OOE5N8L._SY88.jpg
Invalid Image, Skipping...
the image info is:  3317 https://images-na.ssl-images-amazon.com/images/I/81SX3oAWbNL._SY88.jpg
Invalid Image, Skipping...
the image info is:  2912 https://images-na.ssl-images-amazon.com/images/I/718niQ1GEwL._SY88.jpg
Invalid Image, Skipping...
the image info is:  2265 https://images-na.ssl-images-amazon.com/images/I/61OboZT-kcL._SY88.jpg
Invalid Image, Skipping...
the image info is:  2088 https://images-na.ssl-images-amazon.com/images/I/710a2Pyh5lL._SY88.jpg
Invalid Image, Skipping...
the image info is:  3474 https://images-na.ssl-images-amazon.com/images/I/816NMd0LexL._SY88.jpg


In [4]:
downloaded_count = len(all_image_data)  # one entry per successfully decoded image
print("Total number of images downloaded: ", downloaded_count)

Total number of images downloaded:  1640


In [5]:
# Pre-process: force every image to 3-channel RGB and resize it to a standard
# 224x224 size before feature extraction.
IMAGE_SIZE = (224, 224)

for entry in all_image_data:
    pil_img = Image.fromarray(entry[0])

    # convert("RGB") turns 2-D grayscale and 4-channel RGBA arrays into
    # (H, W, 3); without it the 3-value ImageNet Normalize used later fails on
    # 1- or 4-channel tensors. NOTE(review): palette ("P") images must be
    # converted at download time, because asarray() keeps only palette indices.
    pil_img = pil_img.convert("RGB").resize(IMAGE_SIZE)

    # store the resized pixels back in place of the original array
    entry[0] = asarray(pil_img)

print("Pre-processing and resizing done...")

Pre-processing and resizing done...


In [6]:
# Augment the data: for each of the NUM_IMAGES original images, append 10
# transformed copies (same prod_id and url) after all of the originals.
# NOTE: not idempotent. Re-running this cell alone would also augment the
# augmented copies, so re-run from the download cell instead.

NUM_IMAGES = len(all_image_data)

# Ordered table of augmentations; each takes a PIL image and returns a new one.
# The list order fixes the order in which the copies are appended.
AUGMENTATIONS = [
    # Multiply every channel value by 1.5. This is a brightness gain, not a
    # contrast change; for 8-bit images Pillow is expected to saturate values
    # above 255 when it builds the lookup table.
    lambda im: Image.eval(im, lambda x: x * 1.5),
    # mirror horizontally / vertically
    lambda im: im.transpose(Image.FLIP_LEFT_RIGHT),
    lambda im: im.transpose(Image.FLIP_TOP_BOTTOM),
    # rotate() keeps the original canvas size (expand defaults to False)
    lambda im: im.rotate(90),
    lambda im: im.rotate(180),
    lambda im: im.rotate(270),
    # darker / brighter
    lambda im: ImageEnhance.Brightness(im).enhance(0.5),
    lambda im: ImageEnhance.Brightness(im).enhance(1.5),
    # lower / higher contrast
    lambda im: ImageEnhance.Contrast(im).enhance(0.5),
    lambda im: ImageEnhance.Contrast(im).enhance(1.5),
]

augmented_image_data = []
for entry in all_image_data[:NUM_IMAGES]:
    img_array, prod_id, url = entry[0], entry[1], entry[2]
    base_img = Image.fromarray(img_array)
    for augment in AUGMENTATIONS:
        augmented_image_data.append([asarray(augment(base_img)), prod_id, url])

# Appended after all originals, image by image, in AUGMENTATIONS order.
all_image_data.extend(augmented_image_data)

In [7]:
total_images = len(all_image_data)  # originals plus their 10 augmented copies each
print(total_images)

18040


In [8]:
# Extract features of the images using resnet50

from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import torchvision.transforms as transforms

# Load ImageNet-pretrained ResNet-50 and keep every top-level child except the
# last one. In torchvision's ResNet that last child is the final fully-connected
# classification layer (`fc`); there is no softmax module to remove. The
# resulting network therefore outputs pooled feature maps instead of logits.
pretrained = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model = nn.Sequential(*list(pretrained.children())[:-1])
del pretrained

# Inference mode: BatchNorm layers use their running statistics.
model.eval()

# HWC uint8 numpy array -> CHW float tensor in [0, 1], then standardise each
# channel with the ImageNet mean/std the pretrained weights expect.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [9]:
# Extract features of the images using resnet50
import torch

# Extract a feature vector for every image with the truncated ResNet-50,
# processing images in mini-batches instead of one forward pass per image.
BATCH_SIZE = 64

# Use a GPU when one is available; results are copied back to the CPU below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():
    for start in range(0, len(all_image_data), BATCH_SIZE):
        batch_entries = all_image_data[start:start + BATCH_SIZE]

        # transform each image (copied so ToTensor receives a writable array)
        # and stack the results into one (B, C, H, W) tensor
        batch = torch.stack(
            [transform(entry[0].copy()) for entry in batch_entries]
        ).to(device)

        # Flatten everything after the batch dimension. This assumes the
        # backbone ends in a global average pool, so each row is the same
        # vector the per-image `.squeeze()` produced. Batched float results may
        # differ from single-image ones in the last bits.
        features = model(batch).flatten(1).cpu().numpy()

        for entry, feat in zip(batch_entries, features):
            if len(entry) > 3:
                entry[3] = feat  # cell re-run: overwrite, don't append a duplicate
            else:
                entry.append(feat)

In [10]:
# Standardise each image's feature vector independently to zero mean and unit
# variance; the 1e-5 in the denominator avoids division by zero for a constant
# vector.
print("Normalizing the features...")

for record in all_image_data:
    vec = record[3]
    record[3] = (vec - np.mean(vec)) / (np.std(vec) + 1e-5)

Normalizing the features...


In [11]:
# Persist every [image_array, prod_id, url, features] record.
# NOTE: loading a pickle can execute arbitrary code, so only load this file
# from a trusted source.
output_path = 'image_features.pkl'
with open(output_path, mode='wb') as out_file:
    pickle.dump(all_image_data, out_file)