In [40]:
# PyTorch
import torch
import torchvision
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
# For displaying images
from PIL import Image
# import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
# Loading dataset
from datasets import load_dataset
# Transformers
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
# Matrix operations
import numpy as np

# Evaluation
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torchvision.datasets import DatasetFolder,ImageFolder
from torch.utils.data import random_split
from torch.utils.data import DataLoader

In [41]:
def get_device():
    """
    Pick the torch device string to run on.

    Returns:
        "cuda:0" when torch reports that CUDA is available, otherwise "cpu".
    """
    return "cuda:0" if torch.cuda.is_available() else "cpu"

In [42]:
import os
data_path = 'photo'
# Fail fast when the image directory is missing; otherwise just report that it was found.
if not os.path.isdir(data_path):
    raise FileNotFoundError("%s directory does not exist. Please download the data." % data_path)
print("%s exists" % data_path)

photo exists


In [43]:
class EarlyStopper:
    """
    Tracks the best loss seen so far and decides whether training may go on.

    Attributes:
        num_trials: patience budget passed to the constructor.
        trial_counter: number of consecutive non-improving calls recorded so far.
        best_loss: lowest loss passed to `is_continuable` (starts at +inf).
    """

    def __init__(self, num_trials):
        self.best_loss = float('inf')
        self.trial_counter = 0
        self.num_trials = num_trials

    def is_continuable(self, model, loss):
        """
        Record `loss` and report whether training should continue.

        Args:
            model: accepted for interface compatibility; not used here.
            loss: latest loss value to compare against the best one so far.

        Returns:
            True if `loss` improved on the best loss (the counter is reset), or if
            the non-improvement counter can still be incremented below
            `num_trials`; False once the patience budget is exhausted.
        """
        # A strict improvement resets the patience counter.
        if loss < self.best_loss:
            self.best_loss, self.trial_counter = loss, 0
            return True
        # No improvement and no patience left: stop (counter is left untouched).
        if self.trial_counter + 1 >= self.num_trials:
            return False
        self.trial_counter += 1
        return True

### Data Loader

Image checking -- skip this cell once it has been run (its result is hard-coded in the `remove` list below)

In [5]:
import os
from PIL import Image
# Scan every sub-folder of `folder_path`, try to decode each file as an image and
# collect (a) the stems of files PIL cannot read and (b) the extensions seen.
folder_path = r'photo'
extensions = []  # distinct extensions (without the leading dot) of readable images
remove = []      # file stems (name without extension) of unreadable images
for fldr in os.listdir(folder_path):
    sub_folder_path = os.path.join(folder_path, fldr)
    # Skip stray files that sit directly in the root folder.
    if not os.path.isdir(sub_folder_path):
        continue
    for file in os.listdir(sub_folder_path):
        file_path = os.path.join(sub_folder_path, file)
        # splitext copes with names that contain no dot or several dots.
        stem, ext = os.path.splitext(file)
        ext = ext.lstrip('.')
        print('** Path: {}  **'.format(file_path), end="\r", flush=True)
        try:
            # Open *and* convert inside the try block: a corrupt file may open fine
            # and only fail once its pixel data is decoded. The context manager
            # closes the underlying file handle.
            with Image.open(file_path) as im:
                im.convert('RGB')
        except Exception:
            print('Remove 1 image')
            remove.append(stem)
            # Do not fall through with a stale or undefined image object.
            continue
        if ext not in extensions:
            extensions.append(ext)

Remove 1 image\photos\-BIybLxzoFt2d2zbYRcfHA.jpg  **
Remove 1 image\photos\-NGY_19QK2zq913HdiYc5A.jpg  **
Remove 1 image\photos\-YAvSvGUs2ugiJUvIRO6Jw.jpg  **
Remove 1 image\photos\-ZkmgGLJ7AJTjy96nocMNw.jpg  **
Remove 1 image\photos\0fac-NlXqfBO2pWRkmM9aw.jpg  **
Remove 1 image\photos\0TpeNZPs3Gu8s30KVXudcg.jpg  **
** Path: photo\photos\0wAjchhtkxi3k0-kwr1QJQ.jpg  **

In [44]:
# Photo ids (file names without extension) that the image-checking cell above
# flagged as unreadable. They are hard-coded here so that check can be skipped on
# later runs; load_image_rate() leaves these photo ids out of the dataset.
remove = ['-BIybLxzoFt2d2zbYRcfHA',
 '-NGY_19QK2zq913HdiYc5A',
 '-YAvSvGUs2ugiJUvIRO6Jw',
 '-ZkmgGLJ7AJTjy96nocMNw',
 '0fac-NlXqfBO2pWRkmM9aw',
 '0TpeNZPs3Gu8s30KVXudcg',
 '1MOGQBWogR8oJr1WgERi9g',
 '1wd_eyhMrTqUmicDmn4_Kw',
 '2S78q98b_VpBD7vkrDE5-A',
 '43fHlHSYQ_79OBJW1aVUxA',
 '5q-sAvIPl0yNeuAbNBPM1g',
 '6bKuH4FOdaaPInF9NmlQHQ',
 '74upe0h6XxwgzqpdnAh_7Q',
 '7xcWPjcE4mxoQ1AjvvKJZg',
 '9BvYOtforBBP6MvvDogtmw',
 '9jBH61ndIcsheo6FtIHArA',
 '9RDbbAZB0HnL4hndCWB58w',
 '9X4YPM8nYFjf7hY8xUdc6Q',
 'AkiGRjaMKHdJyV7bdHsQjw',
 'amM65inTV6wvx0NNZN5qhg',
 'AMSyCOP3-Eb_ivNA8w1Vhw',
 'ARwqGQZaT0p-XpYYjMXgQg',
 'aUDiJhcFKt0exhyj4Q23Ow',
 'B7xR9CuhRpP52PoehQHVow',
 'bf3ymV0YgP7B6rEoriaU2w',
 'C6n0nKVbgLbYmxSiQ_bFsg',
 'c73YwNh1JsYR5Hz-u_bOrg',
 'CA9z96gGA4y9QOes2Y9eGw',
 'CBxmBYD_5CXIL_F-2PDqmA',
 'cNkUV0sInfh_Py5PP8SHtQ',
 'cwwoZcpqdu2MwdDusNyTdg',
 'DB7BlUpO4LAmC1lCN62hqg',
 'DMCTwC3UT2w5QzHOQoqBPw',
 'E7Wpzn-1fCnVJ8_zKpecPQ',
 'feUGw0P5byOq4U40C77tyQ',
 'gJH0d6Sut4eZDlbV0GCByg',
 'GPMWGVjuCsa6fadnZsEplw',
 'GWLmPwKeBnh2b_7Kv_LQ7w',
 'hChXG-gGWxzGvalse3EYmw',
 'hclqCX1FWcV_TtJJoI3BpQ',
 'hjEfal2a1DWRDu8_AUDLNg',
 'IB2ZjqjtS1W_DadQoPPdgg',
 'IExxMfr1h0bxw54jsanyKA',
 'IkGbGxI8IoOCuVsNB0VLrA',
 'IUsKp87a-v9Yhx6Ftg1m5A',
 'iX-8Xm2G7meRHUg8qhoL1A',
 'j5-4lzg23yGECBa6l1fyRQ',
 'JG5s_bvRF1cSWf1fk9lTbw',
 'JGpfPj8VEvnq1B-Xqr3w-A',
 'JoQ5xekjQUkj8rukJIzqgg',
 'jU-dKl2Ye4L_5x602yoctQ',
 'juDNZOOnkgG3QINFrulsAg',
 'JZZ716oX6_MqH6L_MkWK-A',
 'K6pfRNwGodm1m1gFVQlj-Q',
 'ke4ohxa93GJz0KH9H2kwsQ',
 'kjMBhxBXOUE7SSUQb-YQbw',
 'l2vR3PyVMF3pgIERdDEuiQ',
 'LhLfsQtYwJ5OmEzilubhXQ',
 'lrfy4UVIWtj0xwboLgUreQ',
 'LXT4hCf1lRyUeM4HDBaSvg',
 'l_rMdwgrvjm2PyHyXBcBTw',
 'm3oIKhKKCQD54y1E-dBKSw',
 'MduVueqYTBlEkX-axrh1ug',
 'MZj64XNUN6Og178-6XYR6g',
 'N6hL8FQ84A2DznF2S2Lp7g',
 'n6Q9vNuxz7786ESEfautxQ',
 'NfayhoTudVJQsEF-XlPyjw',
 'NKEFWvRriK-LvagPz2QRxw',
 'nKJ7yiPc0E_DJNtNxmCrhg',
 'O0bVFyP58TOEix6IjERXQA',
 'OK6HsALzFcBAUlrroKHZGg',
 'PFD3ykdI1WVhvZ8IX4PmLQ',
 'PjfJoBrEFgDrxiJy8nyatA',
 'Pk87_8Yndygr4LRUD_H7Hg',
 'pW1IPuTdLIUB61goirbXaA',
 'pY32hIagdxrL4Nsi959EQg',
 'QhATx1B1n8uf8C6siMNTfA',
 'qMlGILrsrzhMDxajNYiyIA',
 'QRUo4vqUu3X9V4eIqBpY8A',
 'qxSXsYMA3aWuAfigeqeOOQ',
 'RhC7TNmFvbR9GWrlrl5dsA',
 'RIeulJUzgemFugkkgg4qgA',
 'rIhUkEmP-j4NcQVW3YuPYQ',
 'rLafN9k3_AF5lZU0cs3LZg',
 'RLtBKD2rlfTaELWejmLBCA',
 'rrfwGSwt3eHxxypfu5PGTA',
 'tlp6LCLDsvL1GjO_kW_plQ',
 'TN4-gAea6ejAdZ-NzYXxng',
 'tSHz7RzlgceAItRejZ396A',
 'TvD36_DdnyCJuXV1SSt3_Q',
 't_sV6mI4oNvbvohhZAyeuA',
 'UG2JuFFa_WxhPEtMOtq-JQ',
 'VSekUmmsGZcX7KaPe_hXyw',
 'w5ABnSadHC8z1lthMQBaBQ',
 'W94rrCn0O5K1lkfD26m4tw',
 'WGmGujPl5BmR_fCUZnoe9w',
 'XX6ujA9CcB5s9y9wCy67-Q',
 'Y3lA41pnMkQNGfyREkf6SA',
 'yAf6R6OSgPo8-mmdDh8qIw',
 'ydm3g1wUWSxJnMPgHk2JhQ',
 'yFjqHyOaNFwzIWTV8EE9hg',
 'yhztPWh5IhaePpUQJNW-dQ',
 'ytJ4lihJrvyzMMRG-WwDNw',
 'YW1WMOkVbdFBrixDnKgoqA',
 'zTzdu2QqLozHpW_qYWF84w',
 '_exWW0g4Svg1Eo2YWsGzbg']

In [45]:
import ijson
from tqdm import tqdm
def load_business_rate(json_path):
    """
    Compute the mean star rating of every business found in a review file.

    The file is streamed with ijson (`multiple_values=True`), and every parsed
    record is expected to carry a 'business_id' and a 'stars' entry.

    Args:
        json_path: path of the review JSON file (opened as UTF-8 text).

    Returns:
        dict mapping business_id -> one-element list `[mean_stars]`, where
        `mean_stars` is a float. The value is a list because load_image_rate()
        appends photo file names to it afterwards.
    """
    # business_id -> [running sum of stars, number of reviews]; keeping only the
    # running totals avoids holding every single review rating in memory.
    totals = {}
    with open(json_path, 'r', encoding='utf-8') as handle:
        events = ijson.parse(handle, multiple_values=True)
        for review in ijson.items(events, ""):
            entry = totals.setdefault(review['business_id'], [0, 0])
            entry[0] += review['stars']
            entry[1] += 1

    return {
        business_id: [float(star_sum / count)]
        for business_id, (star_sum, count) in totals.items()
    }

def load_image_rate(json_path, business_dict):
    """
    Map photo file names to the rating of the business they belong to.

    Only records labelled 'food' are kept, and only when their business is
    present in `business_dict` and their photo id is not listed in the
    module-level `remove` list.

    Args:
        json_path: path of the photo metadata JSON file (streamed with ijson).
        business_dict: dict business_id -> [rating]; photo file names are
            appended to these lists while reading. NOTE: the dict is consumed --
            its lists are mutated and the dict is emptied before returning.

    Returns:
        dict mapping '<photo_id>.jpg' -> rating, ordered by business (in
        `business_dict` order) and, within a business, by appearance in the file.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        events = ijson.parse(handle, multiple_values=True)
        for photo in ijson.items(events, ''):
            if photo['label'] != 'food':
                continue
            business_id = photo['business_id']
            if business_id not in business_dict or photo['photo_id'] in remove:
                continue
            business_dict[business_id].append(photo['photo_id'] + '.jpg')

    target_dict = {}
    for entries in business_dict.values():
        # The first element is the rating; whatever remains are photo file names.
        rate = entries.pop(0)
        target_dict.update(dict.fromkeys(entries, rate))
    business_dict.clear()
    return target_dict
    
    




In [46]:

class RegressionImageFolder(torchvision.datasets.ImageFolder):
    """
    ImageFolder variant for image regression instead of classification.

    After the regular ImageFolder scan, the discovered (path, class index) samples
    are replaced by (path, target value) pairs built from `image_targets`.

    Args:
        root: root directory handed to ImageFolder.
        image_targets: dict mapping an image file name to its target value. Every
            file name is resolved against the directory of the first image that
            ImageFolder discovered, so all images are assumed to live in that one
            directory. The files are not checked for existence here.
        *args, **kwargs: forwarded unchanged to ImageFolder.
    """
    def __init__(self, root, image_targets, *args, **kwargs):
        super().__init__(root, *args, **kwargs)
        paths, _ = zip(*self.imgs)
        # Use os.path instead of a hard-coded '\\' so the directory is derived
        # correctly on every platform (the result is unchanged on Windows).
        image_dir = os.path.dirname(paths[0])

        self.targets = list(image_targets.values())
        paths = [os.path.join(image_dir, name) for name in image_targets.keys()]
        # Keep `samples` and `imgs` in sync, as ImageFolder itself does.
        self.samples = self.imgs = list(zip(paths, self.targets))

def _load_rgb_image(path):
    """Open the image at `path`, decode it as 3-channel RGB and close the file."""
    with open(path, 'rb') as f:
        # convert() forces the pixel data to be read before the file is closed.
        return Image.open(f).convert('RGB')


def make_data(train_tfm,
              review_json='yelp_academic_dataset_review.json',
              photo_json='photos.json',
              image_root='photo/'):
    """
    Build the full image-regression dataset (not a DataLoader).

    Args:
        train_tfm: transform applied to every loaded image.
        review_json: review file used to compute each business' mean rating.
        photo_json: photo metadata file linking photo ids to businesses.
        image_root: root directory scanned by ImageFolder.

    Returns:
        A RegressionImageFolder whose samples are (image path, rating) pairs.
    """
    business_dict = load_business_rate(review_json)
    targets = load_image_rate(photo_json, business_dict)
    # Always hand 3-channel images to the transform: the normalisation step is
    # configured with per-channel statistics, so other image modes would fail.
    return RegressionImageFolder(
        image_root,
        image_targets=targets,
        loader=_load_rgb_image,
        transform=train_tfm,
    )


In [47]:
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

# Reuse the channel statistics and input size stored with the checkpoint's processor.
mu = processor.image_mean
sigma = processor.image_std
size = processor.size

# Standardise every channel with the processor's mean / std.
norm = Normalize(mean=mu, std=sigma)

# Per-image pipeline: resize to 224x224 -> convert the PIL image to a tensor -> normalise.
transf = Compose([Resize((224, 224)), ToTensor(), norm])

In [48]:
data_set = make_data(transf)

# First split: 80% train+validation, remaining 20% test (fixed seed for repeatability).
len_data = len(data_set)
n_trainval = int(len_data * 0.8)
test_split = [n_trainval, len_data - n_trainval]
trainset_data, test_data = random_split(
    dataset=data_set,
    lengths=test_split,
    generator=torch.Generator().manual_seed(0),
)

# Second split: 90% of the train+validation part for training, the rest for validation.
len_train = len(trainset_data)
n_train = int(len_train * 0.9)
val_split = [n_train, len_train - n_train]
train_data, val_data = random_split(
    dataset=trainset_data,
    lengths=val_split,
    generator=torch.Generator().manual_seed(0),
)

# Size of the first dimension of one transformed training image.
first_image = train_data[0][0]
input_dim = first_image.shape[0]

def load_data(batch_size):
    """
    Wrap the module-level train/validation/test datasets in DataLoaders.

    Args:
        batch_size: number of samples per batch for all three loaders.

    Returns:
        Tuple (train_loader, val_loader, test_loader); only the training loader
        shuffles its samples.
    """
    splits = ((train_data, True), (val_data, False), (test_data, False))
    train_loader, val_loader, test_loader = (
        DataLoader(split, batch_size=batch_size, shuffle=do_shuffle)
        for split, do_shuffle in splits
    )
    return train_loader, val_loader, test_loader
# Build the three loaders with a batch size of 32.
# NOTE(review): the Trainer cells visible below are given train_data / val_data /
# test_data directly, so these loaders look unused there -- confirm before relying on them.
train_loader, val_loader, test_loader = load_data(32)

### Model training - Fine Tuning

In [49]:
# Load the pretrained checkpoint named by `model_name` with a single-output head
# (num_labels=1). ignore_mismatched_sizes=True lets loading proceed although the
# checkpoint's classifier has a different shape; that head is newly initialised.
# NOTE(review): with num_labels=1 the model presumably trains as a regressor -- confirm.
model = ViTForImageClassification.from_pretrained(model_name, num_labels=1, ignore_mismatched_sizes=True)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
# Training configuration for the fine-tuning run; checkpoints are written to "photo_to_rate".
args = TrainingArguments(
    "photo_to_rate",
    # Save and evaluate on the same schedule so the best checkpoint can be restored.
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=1e-2,
    load_best_model_at_end=True,
    logging_dir='logs',
    # The datasets yield plain (image, target) tuples, so there are no columns to prune.
    remove_unused_columns=False,
    # Mixed precision is only enabled when CUDA is present, so the notebook
    # still runs on a CPU-only machine.
    fp16=torch.cuda.is_available(),
)

In [51]:
def collate_fn(examples):
    """
    Collate dataset samples into the keyword batch passed to the model.

    Args:
        examples: sequence of samples where index 0 is the image tensor and
            index 1 is the target value.

    Returns:
        dict with "pixel_values" (the images stacked along a new first
        dimension) and "labels" (a tensor built from the target values).
    """
    pixel_list, label_list = [], []
    for sample in examples:
        pixel_list.append(sample[0])
        label_list.append(sample[1])
    return {
        "pixel_values": torch.stack(pixel_list),
        "labels": torch.tensor(label_list),
    }

def compute_metrics(eval_pred):
    """
    Compute the mean absolute error for an evaluation pass.

    Args:
        eval_pred: pair `(predictions, labels)` of array-likes. Both are
            flattened, so a column of predictions and a flat vector of labels
            line up element by element.

    Returns:
        dict {'MAE': mean absolute error between labels and predictions}.
    """
    predictions, labels = eval_pred
    # Flatten explicitly instead of relying on implicit shape handling.
    y_pred = np.asarray(predictions).reshape(-1)
    y_true = np.asarray(labels).reshape(-1)
    # scikit-learn's documented argument order is (y_true, y_pred).
    return {'MAE': mean_absolute_error(y_true, y_pred)}

In [52]:

# Fine-tune the model on the training split, evaluating on the validation split.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
trainer.train()


  0%|          | 0/29175 [00:00<?, ?it/s]

{'loss': 0.9369, 'grad_norm': 13.859210968017578, 'learning_rate': 1.9660668380462726e-05, 'epoch': 0.05}
{'loss': 0.3678, 'grad_norm': 11.545600891113281, 'learning_rate': 1.9317909168808914e-05, 'epoch': 0.1}
{'loss': 0.3442, 'grad_norm': 18.741853713989258, 'learning_rate': 1.89751499571551e-05, 'epoch': 0.15}
{'loss': 0.3338, 'grad_norm': 9.862781524658203, 'learning_rate': 1.8632390745501287e-05, 'epoch': 0.21}
{'loss': 0.339, 'grad_norm': 7.769491195678711, 'learning_rate': 1.8289631533847475e-05, 'epoch': 0.26}
{'loss': 0.3061, 'grad_norm': 3.020697832107544, 'learning_rate': 1.794687232219366e-05, 'epoch': 0.31}
{'loss': 0.3039, 'grad_norm': 8.16762924194336, 'learning_rate': 1.7604113110539848e-05, 'epoch': 0.36}
{'loss': 0.2997, 'grad_norm': 7.172269344329834, 'learning_rate': 1.7261353898886033e-05, 'epoch': 0.41}
{'loss': 0.2907, 'grad_norm': 11.108847618103027, 'learning_rate': 1.691859468723222e-05, 'epoch': 0.46}
{'loss': 0.3073, 'grad_norm': 20.016578674316406, 'learnin

  0%|          | 0/2161 [00:00<?, ?it/s]

{'eval_loss': 0.2838090658187866, 'eval_MAE': 0.40679147839546204, 'eval_runtime': 101.5598, 'eval_samples_per_second': 85.112, 'eval_steps_per_second': 21.278, 'epoch': 1.0}
{'loss': 0.2672, 'grad_norm': 13.932455062866211, 'learning_rate': 1.3149614395886891e-05, 'epoch': 1.03}
{'loss': 0.2161, 'grad_norm': 3.863142490386963, 'learning_rate': 1.2806855184233077e-05, 'epoch': 1.08}
{'loss': 0.2385, 'grad_norm': 6.299255847930908, 'learning_rate': 1.2464095972579264e-05, 'epoch': 1.13}
{'loss': 0.2272, 'grad_norm': 9.900640487670898, 'learning_rate': 1.212133676092545e-05, 'epoch': 1.18}
{'loss': 0.2298, 'grad_norm': 7.394041061401367, 'learning_rate': 1.1778577549271637e-05, 'epoch': 1.23}
{'loss': 0.2396, 'grad_norm': 5.496857643127441, 'learning_rate': 1.1436503856041132e-05, 'epoch': 1.29}
{'loss': 0.2307, 'grad_norm': 16.216854095458984, 'learning_rate': 1.1093744644387318e-05, 'epoch': 1.34}
{'loss': 0.2268, 'grad_norm': 4.526978015899658, 'learning_rate': 1.0750985432733505e-05,

  0%|          | 0/2161 [00:00<?, ?it/s]

{'eval_loss': 0.2693431079387665, 'eval_MAE': 0.39130038022994995, 'eval_runtime': 78.9377, 'eval_samples_per_second': 109.504, 'eval_steps_per_second': 27.376, 'epoch': 2.0}
{'loss': 0.2233, 'grad_norm': 4.629077434539795, 'learning_rate': 6.639245929734362e-06, 'epoch': 2.01}
{'loss': 0.1606, 'grad_norm': 13.096073150634766, 'learning_rate': 6.296486718080549e-06, 'epoch': 2.06}
{'loss': 0.1601, 'grad_norm': 5.3261003494262695, 'learning_rate': 5.954413024850043e-06, 'epoch': 2.11}
{'loss': 0.1681, 'grad_norm': 7.779398441314697, 'learning_rate': 5.6116538131962305e-06, 'epoch': 2.16}
{'loss': 0.1708, 'grad_norm': 7.1827311515808105, 'learning_rate': 5.268894601542417e-06, 'epoch': 2.21}
{'loss': 0.1642, 'grad_norm': 13.48320484161377, 'learning_rate': 4.926135389888604e-06, 'epoch': 2.26}
{'loss': 0.1577, 'grad_norm': 3.9348652362823486, 'learning_rate': 4.584061696658098e-06, 'epoch': 2.31}
{'loss': 0.1595, 'grad_norm': 4.423639297485352, 'learning_rate': 4.241302485004285e-06, 'ep

  0%|          | 0/2161 [00:00<?, ?it/s]

{'eval_loss': 0.27013716101646423, 'eval_MAE': 0.3962051570415497, 'eval_runtime': 78.3988, 'eval_samples_per_second': 110.257, 'eval_steps_per_second': 27.564, 'epoch': 3.0}
{'train_runtime': 4693.0671, 'train_samples_per_second': 49.728, 'train_steps_per_second': 6.217, 'train_loss': 0.24124454710763443, 'epoch': 3.0}


TrainOutput(global_step=29175, training_loss=0.24124454710763443, metrics={'train_runtime': 4693.0671, 'train_samples_per_second': 49.728, 'train_steps_per_second': 6.217, 'train_loss': 0.24124454710763443, 'epoch': 3.0})

In [53]:
# Run the trainer on the held-out test split and print the resulting metrics.
outputs = trainer.predict(test_data)
print(outputs.metrics)

  0%|          | 0/5403 [00:00<?, ?it/s]

{'test_loss': 0.26799851655960083, 'test_MAE': 0.3889068067073822, 'test_runtime': 747.8915, 'test_samples_per_second': 28.895, 'test_steps_per_second': 7.224}


In [64]:
import pandas as pd
# Round targets and predictions to the nearest whole number so they can be
# compared as discrete classes.
y_true = np.round(outputs.label_ids)
y_pred = np.round(outputs.predictions.squeeze())

# Keep the rounded pairs on disk for later inspection.
data = {'y_true': y_true, 'y_pred': y_pred}
df = pd.DataFrame(data)
df.to_csv('output2.csv', index=False)

# Rows are the true classes, columns the predicted ones.
cm = confusion_matrix(y_true, y_pred)
print(cm)

[4. 4. 4. ... 4. 4. 4.]
[[    0   141    18     8     0]
 [    0   915   488   258     0]
 [    0   364  1298  2742    17]
 [    0    57   485 12342   178]
 [    0     1    22  2124   152]]


In [58]:
def normalize_confusion_matrix(confusion_matrix):
    """
    Normalise a confusion matrix so that every row sums to 1.

    Args:
        confusion_matrix: 2-D array-like of counts.

    Returns:
        Float array of the same shape in which each entry is divided by the sum
        of its row. Rows whose sum is 0 are returned as all zeros instead of NaN.

    Side effects:
        Switches numpy's global print options to suppress scientific notation.
    """
    counts = np.asarray(confusion_matrix, dtype=float)
    # Sum of each row, kept as a column so it broadcasts against the matrix.
    row_sum = counts.sum(axis=1, keepdims=True)
    # Divide only where the row sum is non-zero; empty rows keep the zeros from `out`.
    normalized_matrix = np.divide(
        counts, row_sum, out=np.zeros_like(counts), where=row_sum != 0
    )
    np.set_printoptions(suppress=True)
    return normalized_matrix
# Row-normalise the confusion matrix `cm` and print the result.
normalized_by_row = normalize_confusion_matrix(cm)
print(normalized_by_row)

[[0.         0.84431138 0.10778443 0.04790419 0.        ]
 [0.         0.55087297 0.29379892 0.15532812 0.        ]
 [0.         0.08233431 0.29359873 0.62022167 0.00384528]
 [0.         0.0043638  0.03713061 0.94487827 0.01362732]
 [0.         0.00043497 0.00956938 0.92387995 0.0661157 ]]
