In [1]:
# Standard library
import os

# PyTorch
import torch
import torchvision
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
# For dislaying images
from PIL import Image
# import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
# Loading dataset
from datasets import load_dataset
# Transformers
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
# Matrix operations
import numpy as np

# Evaluation
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torchvision.datasets import DatasetFolder,ImageFolder
from torch.utils.data import random_split
from torch.utils.data import DataLoader

In [2]:
def get_device(allow_mps=False):
    """Return the torch device string to run on.

    Args:
        allow_mps: when True and CUDA is unavailable, use Apple's Metal backend
            ("mps") if torch reports it as available. Defaults to False, which
            keeps the plain CUDA-or-CPU behaviour.

    Returns:
        "cuda:0" if a CUDA GPU is available, else "mps" (only when
        ``allow_mps`` is set and MPS is available), else "cpu".
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # Older torch builds have no `torch.backends.mps`; treat that as unavailable.
    mps_backend = getattr(torch.backends, "mps", None)
    if allow_mps and mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

In [3]:
import os
data_path = 'photo'
# Fail fast when the image folder has not been downloaded / extracted yet.
if not os.path.isdir(data_path):
    raise FileNotFoundError("%s directory does not exist. Please download the data." % data_path)
print("%s exists" %data_path)

photo exists


In [4]:
class EarlyStopper:
    """Patience-based early stopping on a loss value.

    Each call to ``is_continuable`` reports whether training should go on.
    An improvement (strictly lower loss) resets the patience counter; after
    ``num_trials`` consecutive non-improving calls it returns False.
    """

    def __init__(self, num_trials):
        self.num_trials = num_trials  # patience: allowed non-improving calls
        self.trial_counter = 0        # consecutive non-improving calls so far
        self.best_loss = float('inf')

    def is_continuable(self, model, loss):
        """Return True to keep training, False to stop.

        ``model`` is accepted for API compatibility but not used.
        """
        if loss < self.best_loss:
            self.best_loss, self.trial_counter = loss, 0
            return True
        if self.trial_counter + 1 < self.num_trials:
            self.trial_counter += 1
            return True
        return False

### Data Loader

Image check — slow; skip this cell once it has been run (its result is hard-coded as `remove` in the next cell).

In [5]:
import os
from PIL import Image

# One-off scan of every photo under `folder_path`: record the photo ids that Pillow
# cannot decode (so they can be excluded from the dataset) and the distinct file
# extensions present.
folder_path = r'photo'
extensions = []   # distinct file extensions seen, without the leading dot
remove = []       # photo ids (file name without extension) that failed to decode
for fldr in os.listdir(folder_path):
    sub_folder_path = os.path.join(folder_path, fldr)
    if not os.path.isdir(sub_folder_path):
        # Stray files at the top level are not class folders; skip them.
        continue
    for file in os.listdir(sub_folder_path):
        file_path = os.path.join(sub_folder_path, file)
        photo_id, ext = os.path.splitext(file)
        ext = ext.lstrip('.')
        if ext not in extensions:
            extensions.append(ext)
        print('** Path: {}  **'.format(file_path), end="\r", flush=True)
        try:
            # `with` closes the file handle; convert() forces a full decode,
            # which is what surfaces truncated / corrupt images.
            with Image.open(file_path) as im:
                im.convert('RGB')
        except Exception as exc:  # best-effort scan: record the failure and keep going
            # Leading newline so the message is not overwritten by the `\r` progress line.
            print('\nRemove 1 image: {} ({})'.format(file_path, exc))
            remove.append(photo_id)

Remove 1 image\photos\-BIybLxzoFt2d2zbYRcfHA.jpg  **
Remove 1 image\photos\-NGY_19QK2zq913HdiYc5A.jpg  **
Remove 1 image\photos\-YAvSvGUs2ugiJUvIRO6Jw.jpg  **
Remove 1 image\photos\-ZkmgGLJ7AJTjy96nocMNw.jpg  **
Remove 1 image\photos\0fac-NlXqfBO2pWRkmM9aw.jpg  **
Remove 1 image\photos\0TpeNZPs3Gu8s30KVXudcg.jpg  **
Remove 1 image\photos\1MOGQBWogR8oJr1WgERi9g.jpg  **
Remove 1 image\photos\1wd_eyhMrTqUmicDmn4_Kw.jpg  **
Remove 1 image\photos\2S78q98b_VpBD7vkrDE5-A.jpg  **
Remove 1 image\photos\43fHlHSYQ_79OBJW1aVUxA.jpg  **
Remove 1 image\photos\5q-sAvIPl0yNeuAbNBPM1g.jpg  **
Remove 1 image\photos\6bKuH4FOdaaPInF9NmlQHQ.jpg  **
Remove 1 image\photos\74upe0h6XxwgzqpdnAh_7Q.jpg  **
Remove 1 image\photos\7xcWPjcE4mxoQ1AjvvKJZg.jpg  **
Remove 1 image\photos\9BvYOtforBBP6MvvDogtmw.jpg  **
Remove 1 image\photos\9jBH61ndIcsheo6FtIHArA.jpg  **
Remove 1 image\photos\9RDbbAZB0HnL4hndCWB58w.jpg  **
Remove 1 image\photos\9X4YPM8nYFjf7hY8xUdc6Q.jpg  **
Remove 1 image\photos\AkiGRjaMKHdJyV7bdHsQjw.j

In [5]:
# Photo ids that failed to open during the one-off image check, cached here so the
# slow scan does not need to be re-run. `load_image_rate` skips these ids.
remove = """
-BIybLxzoFt2d2zbYRcfHA
-NGY_19QK2zq913HdiYc5A
-YAvSvGUs2ugiJUvIRO6Jw
-ZkmgGLJ7AJTjy96nocMNw
0fac-NlXqfBO2pWRkmM9aw
0TpeNZPs3Gu8s30KVXudcg
1MOGQBWogR8oJr1WgERi9g
1wd_eyhMrTqUmicDmn4_Kw
2S78q98b_VpBD7vkrDE5-A
43fHlHSYQ_79OBJW1aVUxA
5q-sAvIPl0yNeuAbNBPM1g
6bKuH4FOdaaPInF9NmlQHQ
74upe0h6XxwgzqpdnAh_7Q
7xcWPjcE4mxoQ1AjvvKJZg
9BvYOtforBBP6MvvDogtmw
9jBH61ndIcsheo6FtIHArA
9RDbbAZB0HnL4hndCWB58w
9X4YPM8nYFjf7hY8xUdc6Q
AkiGRjaMKHdJyV7bdHsQjw
amM65inTV6wvx0NNZN5qhg
AMSyCOP3-Eb_ivNA8w1Vhw
ARwqGQZaT0p-XpYYjMXgQg
aUDiJhcFKt0exhyj4Q23Ow
B7xR9CuhRpP52PoehQHVow
bf3ymV0YgP7B6rEoriaU2w
C6n0nKVbgLbYmxSiQ_bFsg
c73YwNh1JsYR5Hz-u_bOrg
CA9z96gGA4y9QOes2Y9eGw
CBxmBYD_5CXIL_F-2PDqmA
cNkUV0sInfh_Py5PP8SHtQ
cwwoZcpqdu2MwdDusNyTdg
DB7BlUpO4LAmC1lCN62hqg
DMCTwC3UT2w5QzHOQoqBPw
E7Wpzn-1fCnVJ8_zKpecPQ
feUGw0P5byOq4U40C77tyQ
gJH0d6Sut4eZDlbV0GCByg
GPMWGVjuCsa6fadnZsEplw
GWLmPwKeBnh2b_7Kv_LQ7w
hChXG-gGWxzGvalse3EYmw
hclqCX1FWcV_TtJJoI3BpQ
hjEfal2a1DWRDu8_AUDLNg
IB2ZjqjtS1W_DadQoPPdgg
IExxMfr1h0bxw54jsanyKA
IkGbGxI8IoOCuVsNB0VLrA
IUsKp87a-v9Yhx6Ftg1m5A
iX-8Xm2G7meRHUg8qhoL1A
j5-4lzg23yGECBa6l1fyRQ
JG5s_bvRF1cSWf1fk9lTbw
JGpfPj8VEvnq1B-Xqr3w-A
JoQ5xekjQUkj8rukJIzqgg
jU-dKl2Ye4L_5x602yoctQ
juDNZOOnkgG3QINFrulsAg
JZZ716oX6_MqH6L_MkWK-A
K6pfRNwGodm1m1gFVQlj-Q
ke4ohxa93GJz0KH9H2kwsQ
kjMBhxBXOUE7SSUQb-YQbw
l2vR3PyVMF3pgIERdDEuiQ
LhLfsQtYwJ5OmEzilubhXQ
lrfy4UVIWtj0xwboLgUreQ
LXT4hCf1lRyUeM4HDBaSvg
l_rMdwgrvjm2PyHyXBcBTw
m3oIKhKKCQD54y1E-dBKSw
MduVueqYTBlEkX-axrh1ug
MZj64XNUN6Og178-6XYR6g
N6hL8FQ84A2DznF2S2Lp7g
n6Q9vNuxz7786ESEfautxQ
NfayhoTudVJQsEF-XlPyjw
NKEFWvRriK-LvagPz2QRxw
nKJ7yiPc0E_DJNtNxmCrhg
O0bVFyP58TOEix6IjERXQA
OK6HsALzFcBAUlrroKHZGg
PFD3ykdI1WVhvZ8IX4PmLQ
PjfJoBrEFgDrxiJy8nyatA
Pk87_8Yndygr4LRUD_H7Hg
pW1IPuTdLIUB61goirbXaA
pY32hIagdxrL4Nsi959EQg
QhATx1B1n8uf8C6siMNTfA
qMlGILrsrzhMDxajNYiyIA
QRUo4vqUu3X9V4eIqBpY8A
qxSXsYMA3aWuAfigeqeOOQ
RhC7TNmFvbR9GWrlrl5dsA
RIeulJUzgemFugkkgg4qgA
rIhUkEmP-j4NcQVW3YuPYQ
rLafN9k3_AF5lZU0cs3LZg
RLtBKD2rlfTaELWejmLBCA
rrfwGSwt3eHxxypfu5PGTA
tlp6LCLDsvL1GjO_kW_plQ
TN4-gAea6ejAdZ-NzYXxng
tSHz7RzlgceAItRejZ396A
TvD36_DdnyCJuXV1SSt3_Q
t_sV6mI4oNvbvohhZAyeuA
UG2JuFFa_WxhPEtMOtq-JQ
VSekUmmsGZcX7KaPe_hXyw
w5ABnSadHC8z1lthMQBaBQ
W94rrCn0O5K1lkfD26m4tw
WGmGujPl5BmR_fCUZnoe9w
XX6ujA9CcB5s9y9wCy67-Q
Y3lA41pnMkQNGfyREkf6SA
yAf6R6OSgPo8-mmdDh8qIw
ydm3g1wUWSxJnMPgHk2JhQ
yFjqHyOaNFwzIWTV8EE9hg
yhztPWh5IhaePpUQJNW-dQ
ytJ4lihJrvyzMMRG-WwDNw
YW1WMOkVbdFBrixDnKgoqA
zTzdu2QqLozHpW_qYWF84w
_exWW0g4Svg1Eo2YWsGzbg
""".split()

In [6]:
import ijson
from tqdm import tqdm
def load_business_rate(json_path):
    """Compute a rounded mean star rating per business from a JSON review stream.

    ``json_path`` is read as a sequence of concatenated JSON objects (one review
    per object). Each review's ``'stars'`` value is grouped under its
    ``'business_id'``.

    Returns:
        dict mapping business_id -> ``[rating]``: a one-element list holding
        ``round(mean(stars))``. It is a list because ``load_image_rate`` extends
        it with photo file names.

    Note:
        Python's ``round`` rounds exact halves to the even integer
        (2.5 -> 2, 3.5 -> 4). NOTE(review): confirm this tie-breaking is intended.
    """
    stars_by_business = {}
    # ijson is designed for byte streams; binary mode also avoids decoding the
    # file with the platform's default text encoding before parsing.
    with open(json_path, 'rb') as f:
        for review in ijson.items(f, "", multiple_values=True):
            stars_by_business.setdefault(review['business_id'], []).append(review['stars'])

    return {
        business_id: [round(sum(stars) / len(stars))]
        for business_id, stars in stars_by_business.items()
    }

def load_image_rate(json_path, business_dict, exclude=None):
    """Label every food photo with its business's rating.

    Args:
        json_path: JSON object stream of photo metadata; each object is read for
            its 'label', 'business_id' and 'photo_id' fields.
        business_dict: business_id -> ``[rating]`` (as returned by
            ``load_business_rate``). It is consumed: photo file names are
            appended to its lists, the rating is popped off, and the dict is
            cleared before returning.
        exclude: photo ids to skip. Defaults to the notebook-level ``remove`` list.

    Returns:
        dict mapping ``'<photo_id>.jpg'`` -> rating, covering only photos labelled
        'food' whose business has a rating and whose id is not excluded.
    """
    # Set membership is O(1); checking against a list costs O(len(list)) per photo.
    excluded = set(remove if exclude is None else exclude)
    with open(json_path, 'rb') as f:
        for photo in ijson.items(f, '', multiple_values=True):
            if photo['label'] != 'food':
                continue
            business_id = photo['business_id']
            if business_id in business_dict and photo['photo_id'] not in excluded:
                business_dict[business_id].append(photo['photo_id'] + '.jpg')

    target_dict = {}
    for entries in business_dict.values():
        # First element is the business rating; the rest are its photo file names.
        rate = entries.pop(0)
        for file_name in entries:
            target_dict[file_name] = rate
    business_dict.clear()
    return target_dict
    
    




In [7]:

class RegressionImageFolder(torchvision.datasets.ImageFolder):
    """ImageFolder whose samples and targets come from an explicit mapping.

    ``ImageFolder`` is used to scan ``root`` and for its loader/transform
    machinery, but the folder-derived class labels are discarded.
    ``image_targets`` maps image *file names* (e.g. ``'<photo_id>.jpg'``) to
    target values. Each file name is resolved relative to the directory holding
    the first image ImageFolder found, so all images are assumed to live in that
    single folder.
    """
    def __init__(self, root, image_targets, *args, **kwargs):
        super().__init__(root, *args, **kwargs)
        paths, _ = zip(*self.imgs)
        # Folder of the first discovered image. os.path keeps this portable;
        # splitting on a hard-coded '\\' only worked on Windows.
        image_dir = os.path.dirname(paths[0])

        # Iterating the same dict for values and keys keeps targets and paths aligned.
        self.targets = list(image_targets.values())
        paths = [os.path.join(image_dir, name) for name in image_targets.keys()]
        self.samples = self.imgs = list(zip(paths, self.targets))

def make_data(train_tfm, review_path='yelp_academic_dataset_review.json',
              photo_json_path='photos.json', root='photo/'):
    """Build the labelled photo Dataset (not a DataLoader).

    Args:
        train_tfm: transform applied to every loaded image.
        review_path: review JSON stream used to compute per-business ratings.
        photo_json_path: photo metadata JSON stream linking photos to businesses.
        root: ImageFolder root directory containing the photo files.

    Returns:
        RegressionImageFolder yielding ``(transformed_image, rating)`` pairs.
    """
    business_dict = load_business_rate(review_path)
    targets = load_image_rate(photo_json_path, business_dict)
    # No custom loader: ImageFolder's default loader opens each file inside a
    # `with` block and converts it to RGB. Grayscale/RGBA/CMYK photos therefore
    # still produce 3-channel images, and no file handles are left open. The
    # previous `lambda x: Image.open(x)` did neither.
    data = RegressionImageFolder(
        root,
        image_targets=targets,
        transform=train_tfm,
    )
    return data


In [8]:
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

# Reuse the checkpoint's own preprocessing statistics so fine-tuning inputs
# match what the model saw during pre-training.
mu, sigma = processor.image_mean, processor.image_std
size = processor.size
# ViTImageProcessor stores `size` as {"height": H, "width": W}; very old
# releases used a bare int, so accept both.
resize_hw = (size["height"], size["width"]) if isinstance(size, dict) else (size, size)

norm = Normalize(mean=mu, std=sigma)  # per-channel (x - mean) / std

# resize any input photo to the model's input resolution -> PyTorch tensor -> normalize
transf = Compose([
    Resize(resize_hw),
    ToTensor(),
    norm,
])

In [9]:
# Build the full labelled dataset, then split it with a fixed seed (0):
# 80% train+val / 20% test, and the train+val part into 90% train / 10% val.
data_set = make_data(transf)


len_data = len(data_set)
test_split = [int(len_data * 0.8)]
test_split.append(len_data - test_split[0])
trainset_data, test_data = random_split(
    dataset=data_set,
    lengths=test_split,
    generator=torch.Generator().manual_seed(0),
)

len_train = len(trainset_data)
val_split = [int(len_train * 0.9)]
val_split.append(len_train - val_split[0])
train_data, val_data = random_split(
    dataset=trainset_data,
    lengths=val_split,
    generator=torch.Generator().manual_seed(0),
)
# Leading dimension of the first training sample's image tensor (loads one image).
input_dim = train_data[0][0].shape[0]


def load_data(batch_size):
    """Wrap train/val/test in DataLoaders; only the training split is shuffled."""
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=is_train)
        for split, is_train in ((train_data, True), (val_data, False), (test_data, False))
    )


train_loader, val_loader, test_loader = load_data(32)

### Model training - Fine Tuning

In [10]:
# Targets are the rounded star ratings themselves (1..5), fed to the model
# directly as class indices, so the head needs indices 0..5 and index 0 is never
# a target. The previous mapping (0->1 ... 4->5) had no entry for index 5 and
# disagreed with num_labels=6.
itos = {i: i for i in range(6)}
stoi = {i: i for i in range(6)}
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(itos),
    ignore_mismatched_sizes=True,  # replace the 1000-class ImageNet head
    id2label=itos,
    label2id=stoi,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([6]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
args = TrainingArguments(
    "photo_to_rate",
    # Saving and evaluating on the same schedule lets load_best_model_at_end
    # restore the best checkpoint.
    save_strategy="epoch",
    # NOTE(review): newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=1e-2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
    # Samples are (tensor, label) tuples handed to collate_fn, not named columns.
    remove_unused_columns=False,
    # AMP fp16 needs a CUDA device; hard-coding True breaks CPU-only machines.
    fp16=torch.cuda.is_available(),
)

In [16]:
def collate_fn(examples):
    """Batch ``(image_tensor, label)`` samples into the dict the Trainer passes to the model.

    Returns ``{"pixel_values": stacked images, "labels": label tensor}``.
    """
    images = [ex[0] for ex in examples]
    targets = [ex[1] for ex in examples]
    batch = dict(pixel_values=torch.stack(images), labels=torch.tensor(targets))
    return batch

def compute_metrics(eval_pred):
    """Accuracy of arg-max predictions, in the dict format the Trainer expects.

    Args:
        eval_pred: ``(predictions, labels)`` from the Trainer, where
            ``predictions`` holds per-class scores of shape (n, num_classes),
            or a tuple whose first element is those scores.

    Returns:
        ``{'accuracy': float}``.
    """
    predictions, labels = eval_pred
    # When a model returns extra outputs, the Trainer hands over a tuple;
    # presumably the logits come first — confirm if the model config changes.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=1)
    # sklearn's signature is accuracy_score(y_true, y_pred).
    return {'accuracy': accuracy_score(labels, predictions)}

In [17]:
# Display the photo ids excluded from the dataset (the list `load_image_rate` filters on).
remove

['-BIybLxzoFt2d2zbYRcfHA',
 '-NGY_19QK2zq913HdiYc5A',
 '-YAvSvGUs2ugiJUvIRO6Jw',
 '-ZkmgGLJ7AJTjy96nocMNw',
 '0fac-NlXqfBO2pWRkmM9aw',
 '0TpeNZPs3Gu8s30KVXudcg',
 '1MOGQBWogR8oJr1WgERi9g',
 '1wd_eyhMrTqUmicDmn4_Kw',
 '2S78q98b_VpBD7vkrDE5-A',
 '43fHlHSYQ_79OBJW1aVUxA',
 '5q-sAvIPl0yNeuAbNBPM1g',
 '6bKuH4FOdaaPInF9NmlQHQ',
 '74upe0h6XxwgzqpdnAh_7Q',
 '7xcWPjcE4mxoQ1AjvvKJZg',
 '9BvYOtforBBP6MvvDogtmw',
 '9jBH61ndIcsheo6FtIHArA',
 '9RDbbAZB0HnL4hndCWB58w',
 '9X4YPM8nYFjf7hY8xUdc6Q',
 'AkiGRjaMKHdJyV7bdHsQjw',
 'amM65inTV6wvx0NNZN5qhg',
 'AMSyCOP3-Eb_ivNA8w1Vhw',
 'ARwqGQZaT0p-XpYYjMXgQg',
 'aUDiJhcFKt0exhyj4Q23Ow',
 'B7xR9CuhRpP52PoehQHVow',
 'bf3ymV0YgP7B6rEoriaU2w',
 'C6n0nKVbgLbYmxSiQ_bFsg',
 'c73YwNh1JsYR5Hz-u_bOrg',
 'CA9z96gGA4y9QOes2Y9eGw',
 'CBxmBYD_5CXIL_F-2PDqmA',
 'cNkUV0sInfh_Py5PP8SHtQ',
 'cwwoZcpqdu2MwdDusNyTdg',
 'DB7BlUpO4LAmC1lCN62hqg',
 'DMCTwC3UT2w5QzHOQoqBPw',
 'E7Wpzn-1fCnVJ8_zKpecPQ',
 'feUGw0P5byOq4U40C77tyQ',
 'gJH0d6Sut4eZDlbV0GCByg',
 'GPMWGVjuCsa6fadnZsEplw',
 

In [18]:
from transformers import DataCollatorWithPadding

# Batching is done by `collate_fn` (stack image tensors + label tensor) rather
# than a padding collator, since the image transform presumably yields fixed-size tensors.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
trainer.train()


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/29175 [00:00<?, ?it/s]

{'loss': 0.7118, 'grad_norm': 4.9719038009643555, 'learning_rate': 1.9658611825192804e-05, 'epoch': 0.05}
{'loss': 0.674, 'grad_norm': 3.05021595954895, 'learning_rate': 1.931585261353899e-05, 'epoch': 0.1}
{'loss': 0.6752, 'grad_norm': 4.472777843475342, 'learning_rate': 1.8973093401885177e-05, 'epoch': 0.15}
{'loss': 0.6426, 'grad_norm': 4.0106964111328125, 'learning_rate': 1.8630334190231365e-05, 'epoch': 0.21}
{'loss': 0.6262, 'grad_norm': 3.5623209476470947, 'learning_rate': 1.8288260497000857e-05, 'epoch': 0.26}
{'loss': 0.6435, 'grad_norm': 3.9394114017486572, 'learning_rate': 1.7945501285347045e-05, 'epoch': 0.31}
{'loss': 0.6337, 'grad_norm': 5.440028190612793, 'learning_rate': 1.760274207369323e-05, 'epoch': 0.36}
{'loss': 0.6273, 'grad_norm': 2.9689371585845947, 'learning_rate': 1.7259982862039418e-05, 'epoch': 0.41}
{'loss': 0.6019, 'grad_norm': 3.333549737930298, 'learning_rate': 1.6917223650385606e-05, 'epoch': 0.46}
{'loss': 0.6005, 'grad_norm': 10.020845413208008, 'lear

  0%|          | 0/2161 [00:00<?, ?it/s]

{'eval_loss': 1.0246981382369995, 'eval_accuracy': 0.6619620546043499, 'eval_runtime': 81.4434, 'eval_samples_per_second': 106.135, 'eval_steps_per_second': 26.534, 'epoch': 1.0}
{'loss': 0.4722, 'grad_norm': 8.607462882995605, 'learning_rate': 1.3148928877463582e-05, 'epoch': 1.03}
{'loss': 0.3835, 'grad_norm': 3.5362110137939453, 'learning_rate': 1.2806855184233077e-05, 'epoch': 1.08}
{'loss': 0.3921, 'grad_norm': 8.042865753173828, 'learning_rate': 1.2464095972579264e-05, 'epoch': 1.13}
{'loss': 0.3763, 'grad_norm': 13.1644926071167, 'learning_rate': 1.212133676092545e-05, 'epoch': 1.18}
{'loss': 0.4041, 'grad_norm': 9.167098045349121, 'learning_rate': 1.1778577549271637e-05, 'epoch': 1.23}
{'loss': 0.4028, 'grad_norm': 13.384492874145508, 'learning_rate': 1.1435818337617823e-05, 'epoch': 1.29}
{'loss': 0.3795, 'grad_norm': 12.04263687133789, 'learning_rate': 1.1093059125964011e-05, 'epoch': 1.34}
{'loss': 0.4048, 'grad_norm': 6.883047103881836, 'learning_rate': 1.0750299914310198e-

  0%|          | 0/2161 [00:00<?, ?it/s]

{'eval_loss': 1.1215083599090576, 'eval_accuracy': 0.6602267468764461, 'eval_runtime': 102.212, 'eval_samples_per_second': 84.569, 'eval_steps_per_second': 21.142, 'epoch': 2.0}
{'loss': 0.5317, 'grad_norm': 4.492599010467529, 'learning_rate': 6.638560411311055e-06, 'epoch': 2.01}
{'loss': 0.279, 'grad_norm': 2.145479440689087, 'learning_rate': 6.295801199657241e-06, 'epoch': 2.06}
{'loss': 0.2684, 'grad_norm': 2.728562116622925, 'learning_rate': 5.953041988003429e-06, 'epoch': 2.11}
{'loss': 0.3165, 'grad_norm': 8.011964797973633, 'learning_rate': 5.610282776349615e-06, 'epoch': 2.16}
{'loss': 0.2917, 'grad_norm': 0.8533404469490051, 'learning_rate': 5.268209083119109e-06, 'epoch': 2.21}
{'loss': 0.3085, 'grad_norm': 3.2253170013427734, 'learning_rate': 4.925449871465296e-06, 'epoch': 2.26}
{'loss': 0.2908, 'grad_norm': 1.9137020111083984, 'learning_rate': 4.582690659811483e-06, 'epoch': 2.31}
{'loss': 0.2929, 'grad_norm': 9.928186416625977, 'learning_rate': 4.23993144815767e-06, 'epo

  0%|          | 0/2161 [00:00<?, ?it/s]

{'eval_loss': 1.6411774158477783, 'eval_accuracy': 0.6525913928736696, 'eval_runtime': 87.0781, 'eval_samples_per_second': 99.267, 'eval_steps_per_second': 24.817, 'epoch': 3.0}
{'train_runtime': 4584.5544, 'train_samples_per_second': 50.905, 'train_steps_per_second': 6.364, 'train_loss': 0.44653482759111235, 'epoch': 3.0}


TrainOutput(global_step=29175, training_loss=0.44653482759111235, metrics={'train_runtime': 4584.5544, 'train_samples_per_second': 50.905, 'train_steps_per_second': 6.364, 'train_loss': 0.44653482759111235, 'epoch': 3.0})

In [19]:
# Run the fine-tuned model on the held-out test split and report its metrics.
outputs = trainer.predict(test_dataset=test_data)
print(outputs.metrics)

  0%|          | 0/5403 [00:00<?, ?it/s]

{'test_loss': 1.0347788333892822, 'test_accuracy': 0.6640444238778344, 'test_runtime': 734.6712, 'test_samples_per_second': 29.415, 'test_steps_per_second': 7.354}


### Test-set predictions and confusion matrix

In [31]:
import pandas as pd

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)  # predicted class = arg-max over class scores
# Materialize dict_values first: np.array(dict_values) builds a 0-d object array,
# not an array of the label values.
labels = np.array(list(itos.values()))

data = {'y_true': y_true, 'y_pred': y_pred}
df = pd.DataFrame(data)

# Save the per-sample test labels and predictions as a CSV file
df.to_csv('output.csv', index=False)


In [32]:
# rating range: 1 to 5
from sklearn.metrics import confusion_matrix

# Rows = true labels, columns = predicted labels, both in sorted order of the
# labels that occur in y_true or y_pred.
cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
print(cm)

[[    0   150    11     6     0]
 [    0  1073   371   197    20]
 [    0   480  1575  2232   134]
 [    0    87   956 11224   795]
 [    0     3    73  1745   478]]


In [2]:
import numpy as np
def normalize_confusion_matrix(confusion_matrix):
    """Row-normalize a confusion matrix so every non-empty row sums to 1.

    Entry (i, j) becomes the fraction of row-i samples that fall in column j.
    Rows with no samples are returned as all zeros rather than NaN (plain
    division would compute 0/0 there and emit a RuntimeWarning).

    Side effect: calls ``np.set_printoptions(suppress=True)`` globally so printed
    results use fixed-point rather than scientific notation.
    """
    counts = np.asarray(confusion_matrix, dtype=float)
    row_sum = counts.sum(axis=1, keepdims=True)
    normalized_matrix = np.divide(
        counts, row_sum, out=np.zeros_like(counts), where=row_sum != 0
    )
    np.set_printoptions(suppress=True)
    return normalized_matrix
# Row-normalized view of `cm`: each row shows how the samples of one true label
# are distributed over the predicted labels.
normalized_by_row = normalize_confusion_matrix(cm)
print(normalized_by_row)

[[0.         0.89820359 0.06586826 0.03592814 0.        ]
 [0.         0.64599639 0.22335942 0.11860325 0.01204094]
 [0.         0.10857272 0.35625424 0.50486315 0.03030988]
 [0.         0.00666054 0.0731894  0.85928648 0.06086357]
 [0.         0.00130492 0.03175294 0.75902566 0.20791649]]
