In [18]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from torchvision.models import resnet18, ResNet18_Weights
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import warnings
import pandas as pd
warnings.filterwarnings("ignore")
import re
from tqdm import tqdm
from PIL import Image

df = pd.read_csv('kaggle_3m\data.csv')
missing = df[df['age_at_initial_pathologic'].isna()]['Patient']
missing

109    TCGA_HT_A61B
Name: Patient, dtype: object

In [23]:
def get_images(dir):
    images = []
    ids = []

    for subdir in os.listdir(dir):
        path = os.path.join(dir, subdir)
        pattern = r"TCGA_(CS|DU|FG|HT|EZ)_(\w{4})"
        match = re.search(pattern, path)

        if match:
            ids.append(match.group(2))

        if path.startswith("kaggle_3m\TCGA"):
            for image_name in os.listdir(path):
                
                if "mask" in image_name:
                    continue
                elif "TCGA_HT_A61B" in image_name:
                    continue
                else:
                    images.append(os.path.join(path, image_name))
    return images, ids

def get_labels(images, ids):
    ids = [id for image in images for id in ids if  id in image]
    df['num_id'] = df['Patient'].str.extract(r'([^_]+)$')
    labels = []
    for id in ids:
        age = df.loc[df['num_id'] == id, 'age_at_initial_pathologic'].values
        labels.append(age)
    return labels

class CustomImageDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
            
        if self.transform:
            image = self.transform(image)
        
        return image, label

def evaluate(model, loader, device):
    model.eval()
    total_error = 0
    count = 0
    with torch.no_grad():  
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images) 
            total_error += torch.sum(torch.abs(outputs - labels))
            count += images.size(0)
    
    return total_error / count


In [21]:
images, ids = get_images('kaggle_3m')
labels = get_labels(images, ids)
for i in range(len(images)):
    images[i] = cv2.imread(images[i])
    images[i] = np.array(images[i])

transform = transforms.Compose([
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

dataset = CustomImageDataset(images, labels, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

size = len(loader.dataset)
train_size = int(0.7 * size)
val_size = int(0.15 * size)
test_size = size - train_size - val_size

train_set, val_set, test_set = random_split(loader.dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, 1)

device = torch.device("cpu")
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 20


for epoch in range(epochs):
    model.train() 
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  
        outputs = model(images)  
        loss = criterion(outputs.float(), labels.float())

        loss.backward()  
        optimizer.step()  

    MAE = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1}, Mean Absolute Error: {MAE.item()}")

Epoch 1: 100%|██████████| 168/168 [07:03<00:00,  2.52s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[ 7.4971],
        [ 7.4842],
        [11.2821],
        [15.2018],
        [10.3912],
        [17.8211],
        [15.5206],
        [11.0833],
        [13.8383],
        [12.1444],
        [17.8205],
        [17.0349],
        [14.8410],
        [ 8.5709],
        [14.3495],
        [ 9.6847]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[18.2196],
        [17.5698],
        [18.4079],
        [14.1279],
        [-0.9401],
        [13.9089],
        [14.2185],
        [ 8.5673],
        [10.3382],


Epoch 2: 100%|██████████| 168/168 [05:53<00:00,  2.10s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[16.0416],
        [21.5887],
        [29.2023],
        [23.3826],
        [27.5054],
        [27.6937],
        [15.3377],
        [14.2564],
        [17.3841],
        [16.9023],
        [30.5027],
        [31.3137],
        [27.8941],
        [19.6170],
        [25.0640],
        [17.7327]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[23.4410],
        [28.8285],
        [34.6970],
        [31.0158],
        [ 2.8061],
        [24.2905],
        [28.0649],
        [12.6171],
        [13.1885],


Epoch 3: 100%|██████████| 168/168 [04:55<00:00,  1.76s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[17.8757],
        [22.8214],
        [28.7476],
        [32.9252],
        [28.0644],
        [38.0480],
        [26.1817],
        [18.8417],
        [31.3751],
        [26.3522],
        [41.2781],
        [41.0300],
        [42.1501],
        [19.6883],
        [33.7493],
        [27.2586]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[35.3237],
        [41.7294],
        [46.1614],
        [38.1219],
        [ 8.8949],
        [28.8493],
        [33.5083],
        [15.3653],
        [18.4179],


Epoch 4: 100%|██████████| 168/168 [05:24<00:00,  1.93s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[22.2657],
        [20.2407],
        [31.3115],
        [38.2588],
        [42.2897],
        [41.6911],
        [26.9718],
        [26.2436],
        [26.4628],
        [33.5861],
        [44.0590],
        [46.8528],
        [54.1335],
        [24.1797],
        [33.5227],
        [25.6704]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[40.6181],
        [48.6621],
        [49.7279],
        [43.2929],
        [15.6328],
        [27.5788],
        [30.9021],
        [20.3652],
        [21.9212],


Epoch 5: 100%|██████████| 168/168 [04:42<00:00,  1.68s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[26.1989],
        [27.9994],
        [35.5803],
        [41.3492],
        [42.6392],
        [50.0074],
        [32.9436],
        [25.8218],
        [30.2631],
        [39.9504],
        [49.7715],
        [59.0713],
        [67.9099],
        [30.2616],
        [35.3824],
        [32.4715]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[47.0491],
        [59.1991],
        [59.3790],
        [48.3663],
        [21.9813],
        [38.6262],
        [36.8505],
        [27.9328],
        [26.3845],


Epoch 6: 100%|██████████| 168/168 [06:33<00:00,  2.34s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[30.5909],
        [31.8839],
        [37.3545],
        [43.5288],
        [49.6371],
        [55.4143],
        [34.1340],
        [33.0792],
        [35.9492],
        [46.1488],
        [49.7113],
        [58.8851],
        [60.3860],
        [34.8871],
        [41.8715],
        [38.6747]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[50.2321],
        [59.3404],
        [58.4340],
        [55.1158],
        [28.7202],
        [40.3199],
        [43.8438],
        [28.5224],
        [28.9972],


Epoch 7: 100%|██████████| 168/168 [06:20<00:00,  2.26s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[22.6017],
        [26.9203],
        [33.5265],
        [38.7128],
        [39.7987],
        [47.8857],
        [25.0019],
        [27.1895],
        [25.1196],
        [33.5246],
        [44.7998],
        [51.1234],
        [58.2734],
        [29.3480],
        [38.3008],
        [32.3700]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[48.2861],
        [50.9812],
        [54.5792],
        [44.0761],
        [21.1925],
        [32.7175],
        [31.5416],
        [25.3103],
        [25.4192],


Epoch 8: 100%|██████████| 168/168 [06:10<00:00,  2.21s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[30.1318],
        [30.5008],
        [43.0132],
        [49.6539],
        [51.4916],
        [54.8447],
        [37.0628],
        [29.8008],
        [31.9482],
        [42.7677],
        [56.4077],
        [63.5028],
        [60.6255],
        [37.1454],
        [40.3124],
        [37.6469]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[56.5596],
        [64.6008],
        [63.1420],
        [53.8671],
        [27.0856],
        [38.9516],
        [42.8384],
        [29.1261],
        [27.2385],


Epoch 9: 100%|██████████| 168/168 [06:12<00:00,  2.22s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[25.9496],
        [30.8328],
        [33.0838],
        [44.3285],
        [50.2987],
        [50.7676],
        [31.5405],
        [27.9967],
        [30.2842],
        [38.2947],
        [47.2168],
        [59.0997],
        [60.8545],
        [29.5233],
        [39.6749],
        [36.1138]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[49.8643],
        [55.6555],
        [59.3597],
        [50.2922],
        [24.7303],
        [36.6966],
        [38.7868],
        [25.1249],
        [21.4559],


Epoch 10: 100%|██████████| 168/168 [05:55<00:00,  2.12s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[30.9603],
        [32.7418],
        [41.6003],
        [48.5030],
        [44.7051],
        [56.5455],
        [38.6204],
        [31.8829],
        [35.5690],
        [37.8772],
        [55.7973],
        [64.6411],
        [63.9091],
        [33.3537],
        [48.6581],
        [39.7866]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[55.6420],
        [61.4087],
        [61.0853],
        [53.4124],
        [25.2524],
        [41.6393],
        [44.5573],
        [31.4368],
        [26.2708],


Epoch 11: 100%|██████████| 168/168 [06:13<00:00,  2.22s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[29.7972],
        [32.3020],
        [38.9212],
        [47.2105],
        [42.9820],
        [53.4713],
        [38.8862],
        [32.1546],
        [34.7931],
        [43.2143],
        [53.0095],
        [58.5357],
        [56.2493],
        [30.3451],
        [40.3299],
        [42.5599]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[53.6358],
        [58.1296],
        [57.2820],
        [54.8926],
        [25.1243],
        [40.4860],
        [44.3279],
        [29.9346],
        [28.0531],


Epoch 12: 100%|██████████| 168/168 [06:18<00:00,  2.25s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[25.3905],
        [28.5845],
        [30.4447],
        [42.5104],
        [43.6139],
        [51.5495],
        [34.1091],
        [31.7108],
        [28.0526],
        [34.8734],
        [44.4755],
        [58.4042],
        [59.0642],
        [26.0214],
        [35.4770],
        [32.2010]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[53.9501],
        [53.7736],
        [58.0301],
        [46.4503],
        [22.6746],
        [35.7184],
        [36.1547],
        [23.8304],
        [24.0906],


Epoch 13: 100%|██████████| 168/168 [05:52<00:00,  2.10s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[25.0608],
        [30.9130],
        [32.7959],
        [43.1123],
        [43.8309],
        [49.9562],
        [32.5880],
        [27.7967],
        [33.4325],
        [36.9564],
        [49.4304],
        [55.7815],
        [59.2258],
        [26.8339],
        [38.6327],
        [36.2412]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[53.9720],
        [57.6295],
        [57.3611],
        [50.6145],
        [23.3654],
        [35.6662],
        [38.3431],
        [26.7797],
        [23.9136],


Epoch 14: 100%|██████████| 168/168 [04:43<00:00,  1.69s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[30.1536],
        [33.0792],
        [42.6803],
        [46.7606],
        [43.1874],
        [58.4155],
        [36.6406],
        [30.9524],
        [36.6121],
        [41.1695],
        [56.6582],
        [62.2237],
        [64.5761],
        [40.3739],
        [46.0780],
        [40.6564]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[57.0884],
        [60.8135],
        [63.1931],
        [52.4041],
        [25.1257],
        [45.4268],
        [46.1317],
        [30.9580],
        [28.1923],


Epoch 15: 100%|██████████| 168/168 [04:33<00:00,  1.63s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[26.5229],
        [31.0289],
        [37.4073],
        [44.9091],
        [41.2239],
        [49.8912],
        [31.2721],
        [29.0206],
        [29.1935],
        [34.4585],
        [49.2003],
        [57.5928],
        [56.0982],
        [30.6304],
        [38.3565],
        [33.2101]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[53.0315],
        [54.6628],
        [58.1864],
        [51.7717],
        [22.8979],
        [39.4159],
        [43.8697],
        [29.0644],
        [23.4434],


Epoch 16: 100%|██████████| 168/168 [04:32<00:00,  1.62s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[25.6352],
        [31.5061],
        [36.6336],
        [44.4426],
        [45.9907],
        [51.1479],
        [31.4413],
        [30.0505],
        [28.8234],
        [37.9283],
        [49.0231],
        [57.2475],
        [61.3145],
        [31.2167],
        [40.4912],
        [36.1228]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[51.3881],
        [56.2349],
        [58.5011],
        [50.4534],
        [25.4263],
        [37.2338],
        [38.7931],
        [26.9255],
        [30.9904],


Epoch 17: 100%|██████████| 168/168 [04:20<00:00,  1.55s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[28.3175],
        [34.4910],
        [39.3229],
        [46.1410],
        [47.9356],
        [52.3556],
        [33.8425],
        [30.3984],
        [31.6214],
        [38.3830],
        [50.0426],
        [60.2829],
        [67.6739],
        [31.0807],
        [41.0651],
        [36.8512]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[58.0256],
        [58.5964],
        [63.1193],
        [54.1583],
        [28.5361],
        [41.1372],
        [40.2597],
        [29.0692],
        [24.5049],


Epoch 18: 100%|██████████| 168/168 [04:30<00:00,  1.61s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[26.7859],
        [29.7497],
        [35.6113],
        [44.2712],
        [50.1125],
        [51.5239],
        [32.0412],
        [34.0784],
        [31.5843],
        [40.4283],
        [47.1458],
        [58.5990],
        [62.3280],
        [32.7333],
        [40.6776],
        [36.0447]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[51.0544],
        [55.3356],
        [57.9156],
        [49.2628],
        [26.2052],
        [40.2323],
        [40.0193],
        [27.6176],
        [22.8105],


Epoch 19: 100%|██████████| 168/168 [05:03<00:00,  1.81s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[29.6247],
        [34.7014],
        [36.1593],
        [43.9200],
        [51.5022],
        [52.7147],
        [32.9067],
        [31.7801],
        [31.9979],
        [38.7201],
        [50.4754],
        [59.1262],
        [63.8474],
        [30.8017],
        [43.3754],
        [35.3079]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[52.2426],
        [55.2667],
        [62.1090],
        [47.2078],
        [25.8596],
        [38.5613],
        [40.4715],
        [27.1137],
        [23.1516],


Epoch 20: 100%|██████████| 168/168 [05:43<00:00,  2.05s/it]


tensor([[30.],
        [34.],
        [37.],
        [48.],
        [38.],
        [54.],
        [30.],
        [31.],
        [28.],
        [40.],
        [57.],
        [61.],
        [59.],
        [31.],
        [33.],
        [38.]], dtype=torch.float64)
tensor([[28.9128],
        [33.1246],
        [36.8260],
        [47.3222],
        [52.0226],
        [53.3593],
        [35.5737],
        [30.8013],
        [31.7417],
        [43.3802],
        [50.7536],
        [62.6994],
        [62.0712],
        [35.3292],
        [41.6393],
        [37.4848]])
tensor([[54.],
        [57.],
        [59.],
        [53.],
        [27.],
        [43.],
        [41.],
        [29.],
        [22.],
        [58.],
        [30.],
        [53.],
        [70.],
        [61.],
        [67.],
        [48.]], dtype=torch.float64)
tensor([[55.7828],
        [59.1282],
        [62.2863],
        [52.9000],
        [27.9047],
        [40.2417],
        [43.6211],
        [30.8499],
        [27.9161],


In [22]:
MAE = evaluate(model, test_loader, device)
print(f"Mean Absolute Error: {MAE.item()}")

tensor([[51.],
        [66.],
        [59.],
        [70.],
        [49.],
        [66.],
        [40.],
        [33.],
        [36.],
        [47.],
        [61.],
        [33.],
        [49.],
        [51.],
        [41.],
        [30.]], dtype=torch.float64)
tensor([[51.4234],
        [69.2569],
        [58.6057],
        [70.9119],
        [50.0224],
        [67.7363],
        [40.9583],
        [40.9038],
        [36.0853],
        [42.9509],
        [64.5266],
        [41.1832],
        [46.4910],
        [50.3956],
        [42.7205],
        [31.5594]])
tensor([[48.],
        [30.],
        [43.],
        [58.],
        [70.],
        [39.],
        [28.],
        [49.],
        [34.],
        [29.],
        [34.],
        [48.],
        [29.],
        [32.],
        [57.],
        [30.]], dtype=torch.float64)
tensor([[58.1993],
        [32.7801],
        [41.6715],
        [59.3292],
        [67.7300],
        [48.4527],
        [25.9687],
        [50.3198],
        [34.0857],
