In [1]:
import os
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTConfig
import torch.nn.init as init
from torch import optim

In [3]:
class CustomImageDataset(torchvision.datasets.VisionDataset):
    def __init__(self, root, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Assuming that the images are organized in subfolders representing classes.
        self.classes, self.class_to_idx = self._find_classes(root)
        self.samples = self.make_dataset(root, self.class_to_idx)
        self.targets = [s[1] for s in self.samples]

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        """
        if not os.path.isdir(dir):
            raise FileNotFoundError(f"Couldn't find directory: {dir}")
        classes = sorted(entry.name for entry in os.scandir(dir) if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        classes = classes[0:]
        return classes, class_to_idx

    def make_dataset(self, directory, class_to_idx):
        instances = []
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir)):
                for fname in fnames:
                    if fname.lower().endswith(('jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif')):
                        path = os.path.join(root, fname)
                        instances.append((path, class_index))
        return instances

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        
        sample = self.loader(path)
            

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)


    def loader(self, path):
        try:
            # Open the image file, avoiding automatic resource management to control the closing process
            with open(path, 'rb') as f:
                image = Image.open(path)
                if image.mode in ['I', 'F', 'I;16', 'I;16L', 'I;16B', 'I;16N']:
                    image_array = np.array(image, dtype=np.int32)
                    max_val = image_array.max()
                    if max_val > 0:image = Image.fromarray((image_array / max_val * 255).astype(np.uint8))
                if image.mode != 'RGB': image = image.convert('RGB')

            return image
        except Exception:
            print(f"Failed to load image {path}. Skipping.")
            return None


class ClassificationVisualizeDataset(torch.utils.data.Dataset):

    def __init__(self, data_dir, transforms):
        self.data_dir   = data_dir
        self.transforms = transforms

        # This one-liner basically generates a sorted list of full paths to each image in the test directory
        self.img_paths  = list(map(lambda fname: os.path.join(self.data_dir, fname), sorted(os.listdir(self.data_dir))))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image = Image.open(self.img_paths[idx])
        if image.mode in ['I', 'F', 'I;16', 'I;16L', 'I;16B', 'I;16N']:
            image_array = np.array(image, dtype=np.int32)
            max_val = image_array.max()
            if max_val > 0:image = Image.fromarray((image_array / max_val * 255).astype(np.uint8))
        if image.mode != 'RGB': image = image.convert('RGB')
        return self.transforms(image)

mean=[0.4195, 0.3118, 0.1418]
std=[0.2289, 0.2239, 0.2249]
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomAffine(degrees=[-15, 15], scale=(1/1.2, 1.2)),  # Random stretching
    # torchvision.transforms.GaussianBlur(kernel_size=(7,7),sigma=(0.1,0.2)),
    torchvision.transforms.Resize((224, 224)), 
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize(mean,std)
])
valid_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean,std)
])

TRAIN_DIR = '/ocean/projects/cis230079p/shared/CapStone-VeyTel-2024/datasets/data-clean/train'
VAL_DIR = "/ocean/projects/cis230079p/shared/CapStone-VeyTel-2024/datasets/data-clean/dev"
TEST_DIR = "/ocean/projects/cis230079p/shared/CapStone-VeyTel-2024/datasets/data-clean/test"

train_dataset   = CustomImageDataset(TRAIN_DIR, transform= train_transforms)
valid_dataset   = CustomImageDataset(VAL_DIR, transform= valid_transforms)
test_dataset    = ClassificationVisualizeDataset(TEST_DIR, transforms =valid_transforms)

train_loader = torch.utils.data.DataLoader(
    dataset     = train_dataset,
    batch_size  = 32,
    shuffle     = True,
    num_workers = 10,
    pin_memory = True
)

valid_loader = torch.utils.data.DataLoader(
    dataset     = valid_dataset,
    batch_size  = 32,
    shuffle     = False,
    num_workers = 10,
    pin_memory = True
)

test_loader = torch.utils.data.DataLoader(
    dataset     = test_dataset,
    batch_size  = 32,
    shuffle     = False,
    num_workers = 5,
    pin_memory  = True
)

In [4]:
config = ViTConfig.from_pretrained('google/vit-base-patch16-224', num_labels=3, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1)
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    config=config,
    ignore_mismatched_sizes=True  # ignore the final classifier layer size mismatch
)

init.kaiming_normal_(model.classifier.weight, mode='fan_out')
init.constant_(model.classifier.bias, 0)

classifier_params = ["classifier"]
pretrained_params = []
classifier_parameters = []

for name, param in model.named_parameters():
    if any(cp in name for cp in classifier_params):
        classifier_parameters.append(param)
        weight_decay = 1e-4
    else:
        pretrained_params.append(param)
        weight_decay = 0

optimizer_grouped_parameters = [
    {'params': pretrained_params, 'weight_decay': weight_decay, 'lr': 5e-6},
    {'params': classifier_parameters, 'weight_decay': weight_decay, 'lr': 1e-4}
]

optimizer = optim.Adam(optimizer_grouped_parameters)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# load from checkpoint
def load_checkpoint(filepath, model, optimizer):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch']
    best_val_loss = checkpoint['best_val_loss']
    return model, optimizer, start_epoch, best_val_loss

checkpoint_path = 'checkpoints/checkpoint.pth'
model, optimizer, start_epoch, best_val_loss = load_checkpoint(checkpoint_path, model, optimizer)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [28]:
import torch
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

def train(model, train_loader, valid_loader, optimizer, device=device, epochs=10, checkpoint_dir=None, patience=3):
def train(model, train_loader, valid_loader, optimizer, device, epochs, checkpoint_dir, start_epoch=0, best_val_loss=float('inf')):
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1)

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    best_val_loss = float('inf')
    # epochs_no_improve = 0 

    for epoch in range(start_epoch, epochs):
        model.train()
        running_loss = 0.0
        total_correct = 0
        total_samples = 0

        batch_bar = tqdm(total=len(train_loader.dataset), desc=f"Epoch {epoch + 1}", position=0)

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).sum().item()
            total_correct += correct
            total_samples += inputs.size(0)

            batch_bar.set_postfix(
                acc="{:.04f}%".format(100 * total_correct / total_samples),
                loss="{:.04f}".format(running_loss / total_samples),
                num_correct=total_correct
            )
            batch_bar.update(inputs.size(0))

        batch_bar.close()
        
        # Calculate and append training loss and accuracy
        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = 100 * total_correct / total_samples
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # Validate and append validation loss and accuracy
        val_loss, val_accuracy = validate(model, valid_loader, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        scheduler.step(val_loss)
        print(f'Current learning rate is {scheduler.get_last_lr()}')

        # Checkpoint saving
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint.pth"))
            print(f"Checkpoint saved with validation loss {val_loss:.4f}.")
        # else:
        #     epochs_no_improve += 1
        #     if epochs_no_improve >= patience:
        #         print(f"Early stopping triggered after {epoch + 1} epochs.")
        #         break
    
    final_model_path = os.path.join(checkpoint_dir, f'final_model_{epoch + 1}.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model state saved to {final_model_path}")
    
    # Plot training and validation losses
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Losses')
    plt.legend()
    plt.show()

def validate(model, valid_loader, device):
    model.eval()
    total_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    total_samples = 0
    total_correct = 0

    for inputs, labels in valid_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        total_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += inputs.size(0)

    validation_loss = total_loss / len(valid_loader.dataset)
    validation_accuracy = 100 * total_correct / total_samples
    print(f'Validation Loss: {validation_loss:.4f}')
    print(f'Validation Accuracy: {validation_accuracy:.2f}%')
    return validation_loss, validation_accuracy

In [None]:
train(model, train_loader, valid_loader, optimizer, epochs=30, checkpoint_dir='checkpoints')

Epoch 1: 100%|██████████| 19588/19588 [14:19<00:00, 22.79it/s, acc=54.1556%, loss=2.3711, num_correct=10608]


Validation Loss: 1.1076
Validation Accuracy: 64.95%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 1.1076.


Epoch 2: 100%|██████████| 19588/19588 [17:32<00:00, 18.62it/s, acc=56.1824%, loss=2.0534, num_correct=11005]


Validation Loss: 1.0140
Validation Accuracy: 67.97%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 1.0140.


Epoch 3: 100%|██████████| 19588/19588 [16:47<00:00, 19.43it/s, acc=58.1070%, loss=1.7528, num_correct=11382]


Validation Loss: 0.9792
Validation Accuracy: 68.06%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.9792.


Epoch 4: 100%|██████████| 19588/19588 [14:20<00:00, 22.76it/s, acc=59.1025%, loss=1.5860, num_correct=11577]


Validation Loss: 1.0405
Validation Accuracy: 67.16%
Current learning rate is [5e-06, 0.0001]


Epoch 5: 100%|██████████| 19588/19588 [16:21<00:00, 19.95it/s, acc=58.6022%, loss=1.5451, num_correct=11479]


Validation Loss: 0.8732
Validation Accuracy: 68.95%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.8732.


Epoch 6: 100%|██████████| 19588/19588 [14:19<00:00, 22.78it/s, acc=60.7821%, loss=1.3400, num_correct=11906]


Validation Loss: 0.9057
Validation Accuracy: 66.58%
Current learning rate is [5e-06, 0.0001]


Epoch 7: 100%|██████████| 19588/19588 [20:46<00:00, 15.71it/s, acc=61.5785%, loss=1.2618, num_correct=12062] 


Validation Loss: 0.8224
Validation Accuracy: 67.65%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.8224.


Epoch 8: 100%|██████████| 19588/19588 [20:19<00:00, 16.06it/s, acc=61.7929%, loss=1.1951, num_correct=12104]  


Validation Loss: 0.8122
Validation Accuracy: 67.81%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.8122.


Epoch 9: 100%|██████████| 19588/19588 [17:49<00:00, 18.31it/s, acc=62.9467%, loss=1.0992, num_correct=12330] 


Validation Loss: 0.8862
Validation Accuracy: 68.38%
Current learning rate is [5e-06, 0.0001]


Epoch 10: 100%|██████████| 19588/19588 [14:16<00:00, 22.88it/s, acc=63.1509%, loss=1.0605, num_correct=12370]


Validation Loss: 0.7719
Validation Accuracy: 69.85%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.7719.


Epoch 11: 100%|██████████| 19588/19588 [14:45<00:00, 22.12it/s, acc=63.9473%, loss=1.0160, num_correct=12526]


Validation Loss: 0.7574
Validation Accuracy: 71.16%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.7574.


Epoch 12: 100%|██████████| 19588/19588 [14:51<00:00, 21.96it/s, acc=64.4068%, loss=0.9718, num_correct=12616]


Validation Loss: 0.7355
Validation Accuracy: 70.59%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.7355.


Epoch 13: 100%|██████████| 19588/19588 [19:49<00:00, 16.46it/s, acc=64.8713%, loss=0.9240, num_correct=12707]


Validation Loss: 0.7197
Validation Accuracy: 70.42%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.7197.


Epoch 14: 100%|██████████| 19588/19588 [14:01<00:00, 23.28it/s, acc=65.1521%, loss=0.9118, num_correct=12762]


Validation Loss: 0.7035
Validation Accuracy: 71.16%
Current learning rate is [5e-06, 0.0001]
Checkpoint saved with validation loss 0.7035.


Epoch 15: 100%|██████████| 19588/19588 [14:10<00:00, 23.03it/s, acc=65.9945%, loss=0.8779, num_correct=12927]


Validation Loss: 0.7094
Validation Accuracy: 71.32%
Current learning rate is [5e-06, 0.0001]


Epoch 16: 100%|██████████| 19588/19588 [13:56<00:00, 23.42it/s, acc=66.5203%, loss=0.8418, num_correct=13030]


Validation Loss: 0.7128
Validation Accuracy: 71.24%
Current learning rate is [5.000000000000001e-07, 1e-05]


Epoch 17: 100%|██████████| 19588/19588 [14:12<00:00, 22.96it/s, acc=67.4750%, loss=0.7928, num_correct=13217]


Validation Loss: 0.6760
Validation Accuracy: 71.65%
Current learning rate is [5.000000000000001e-07, 1e-05]
Checkpoint saved with validation loss 0.6760.


Epoch 18: 100%|██████████| 19588/19588 [14:23<00:00, 22.69it/s, acc=68.3122%, loss=0.7766, num_correct=13381]


Validation Loss: 0.6804
Validation Accuracy: 70.75%
Current learning rate is [5.000000000000001e-07, 1e-05]


Epoch 19: 100%|██████████| 19588/19588 [15:04<00:00, 21.66it/s, acc=68.6339%, loss=0.7690, num_correct=13444]


Validation Loss: 0.6897
Validation Accuracy: 70.67%
Current learning rate is [5.000000000000001e-08, 1.0000000000000002e-06]


Epoch 20:  18%|█▊        | 3520/19588 [02:33<05:15, 50.86it/s, acc=66.7045%, loss=0.7910, num_correct=2348]

In [7]:
train(model, train_loader, valid_loader, epochs=50)

Epoch 1: 100%|██████████| 19588/19588 [14:06<00:00, 23.13it/s, acc=48.6829%, loss=4.3610, lr=0.000010, num_correct=9536]


Validation Loss: 1.4560
Validation Accuracy: 63.73%
Checkpoint saved with validation loss 1.4560.


Epoch 2: 100%|██████████| 19588/19588 [24:33<00:00, 13.30it/s, acc=53.8085%, loss=2.8218, lr=0.000010, num_correct=10540]


Validation Loss: 1.3408
Validation Accuracy: 66.01%
Checkpoint saved with validation loss 1.3408.


Epoch 3: 100%|██████████| 19588/19588 [13:13<00:00, 24.69it/s, acc=55.9220%, loss=2.2320, lr=0.000010, num_correct=10954]


Validation Loss: 1.1789
Validation Accuracy: 67.40%
Checkpoint saved with validation loss 1.1789.


Epoch 4: 100%|██████████| 19588/19588 [14:52<00:00, 21.95it/s, acc=57.3668%, loss=1.8392, lr=0.000010, num_correct=11237]


Validation Loss: 1.1031
Validation Accuracy: 66.91%
Checkpoint saved with validation loss 1.1031.


Epoch 5: 100%|██████████| 19588/19588 [12:57<00:00, 25.19it/s, acc=58.3572%, loss=1.6220, lr=0.000010, num_correct=11431]


Validation Loss: 0.9683
Validation Accuracy: 66.83%
Checkpoint saved with validation loss 0.9683.


Epoch 6: 100%|██████████| 19588/19588 [12:53<00:00, 25.34it/s, acc=59.9653%, loss=1.3794, lr=0.000010, num_correct=11746]


Validation Loss: 0.9495
Validation Accuracy: 67.57%
Checkpoint saved with validation loss 0.9495.


Epoch 7: 100%|██████████| 19588/19588 [12:53<00:00, 25.33it/s, acc=60.5830%, loss=1.2771, lr=0.000010, num_correct=11867]


Validation Loss: 0.9871
Validation Accuracy: 67.24%


Epoch 8: 100%|██████████| 19588/19588 [12:35<00:00, 25.92it/s, acc=61.6398%, loss=1.1478, lr=0.000010, num_correct=12074]


Validation Loss: 0.9486
Validation Accuracy: 67.57%
Checkpoint saved with validation loss 0.9486.


Epoch 9: 100%|██████████| 19588/19588 [13:51<00:00, 23.56it/s, acc=62.4464%, loss=1.0897, lr=0.000010, num_correct=12232] 


Validation Loss: 0.9770
Validation Accuracy: 67.48%


Epoch 10: 100%|██████████| 19588/19588 [12:39<00:00, 25.78it/s, acc=63.1305%, loss=1.0107, lr=0.000010, num_correct=12366]


Validation Loss: 0.9167
Validation Accuracy: 65.28%
Checkpoint saved with validation loss 0.9167.


Epoch 11: 100%|██████████| 19588/19588 [13:23<00:00, 24.38it/s, acc=63.6461%, loss=0.9790, lr=0.000010, num_correct=12467]


Validation Loss: 0.8434
Validation Accuracy: 68.87%
Checkpoint saved with validation loss 0.8434.


Epoch 12: 100%|██████████| 19588/19588 [13:23<00:00, 24.38it/s, acc=64.7488%, loss=0.9038, lr=0.000010, num_correct=12683]


Validation Loss: 0.8074
Validation Accuracy: 66.42%
Checkpoint saved with validation loss 0.8074.


Epoch 13: 100%|██████████| 19588/19588 [12:47<00:00, 25.51it/s, acc=65.2593%, loss=0.8670, lr=0.000010, num_correct=12783]


Validation Loss: 1.0228
Validation Accuracy: 66.18%


Epoch 14: 100%|██████████| 19588/19588 [13:25<00:00, 24.33it/s, acc=65.5708%, loss=0.8456, lr=0.000010, num_correct=12844]


Validation Loss: 0.7685
Validation Accuracy: 70.75%
Checkpoint saved with validation loss 0.7685.


Epoch 15: 100%|██████████| 19588/19588 [19:09<00:00, 17.04it/s, acc=67.0155%, loss=0.7993, lr=0.000010, num_correct=13127]


Validation Loss: 0.8110
Validation Accuracy: 68.63%


Epoch 16: 100%|██████████| 19588/19588 [13:26<00:00, 24.28it/s, acc=66.7041%, loss=0.8071, lr=0.000010, num_correct=13066]


Validation Loss: 0.7470
Validation Accuracy: 68.46%
Checkpoint saved with validation loss 0.7470.


Epoch 17: 100%|██████████| 19588/19588 [13:27<00:00, 24.26it/s, acc=67.0615%, loss=0.7808, lr=0.000010, num_correct=13136]


Validation Loss: 0.7553
Validation Accuracy: 69.93%


Epoch 18: 100%|██████████| 19588/19588 [27:57<00:00, 11.68it/s, acc=68.0978%, loss=0.7469, lr=0.000010, num_correct=13339]


Validation Loss: 0.7576
Validation Accuracy: 66.99%


Epoch 19: 100%|██████████| 19588/19588 [13:15<00:00, 24.64it/s, acc=69.6804%, loss=0.6845, lr=0.000001, num_correct=13649]


Validation Loss: 0.7101
Validation Accuracy: 70.83%
Checkpoint saved with validation loss 0.7101.


Epoch 20: 100%|██████████| 19588/19588 [13:45<00:00, 23.72it/s, acc=70.2573%, loss=0.6796, lr=0.000001, num_correct=13762]


Validation Loss: 0.7674
Validation Accuracy: 68.14%


Epoch 21: 100%|██████████| 19588/19588 [12:52<00:00, 25.34it/s, acc=70.1654%, loss=0.6759, lr=0.000001, num_correct=13744]


Validation Loss: 0.7447
Validation Accuracy: 70.92%


Epoch 22: 100%|██████████| 19588/19588 [23:27<00:00, 13.91it/s, acc=70.4002%, loss=0.6683, lr=0.000000, num_correct=13790]  


Validation Loss: 0.7253
Validation Accuracy: 70.34%


Epoch 23: 100%|██████████| 19588/19588 [14:15<00:00, 22.89it/s, acc=70.1297%, loss=0.6771, lr=0.000000, num_correct=13737]


Validation Loss: 0.7232
Validation Accuracy: 70.67%


Epoch 24: 100%|██████████| 19588/19588 [12:39<00:00, 25.80it/s, acc=70.9108%, loss=0.6665, lr=0.000000, num_correct=13890]


Validation Loss: 0.7243
Validation Accuracy: 70.42%


Epoch 25: 100%|██████████| 19588/19588 [12:45<00:00, 25.60it/s, acc=70.5432%, loss=0.6664, lr=0.000000, num_correct=13818]


Validation Loss: 0.7247
Validation Accuracy: 70.51%


Epoch 26: 100%|██████████| 19588/19588 [12:41<00:00, 25.72it/s, acc=70.6300%, loss=0.6685, lr=0.000000, num_correct=13835]


Validation Loss: 0.7240
Validation Accuracy: 70.51%


Epoch 27: 100%|██████████| 19588/19588 [13:18<00:00, 24.52it/s, acc=70.5891%, loss=0.6703, lr=0.000000, num_correct=13827]


Validation Loss: 0.7249
Validation Accuracy: 70.51%


Epoch 28: 100%|██████████| 19588/19588 [13:03<00:00, 25.00it/s, acc=70.6861%, loss=0.6630, lr=0.000000, num_correct=13846]


Validation Loss: 0.7240
Validation Accuracy: 70.51%


  image_array = np.array(image, dtype=np.int32)


Failed to load image /ocean/projects/cis230079p/shared/CapStone-VeyTel-2024/datasets/data-clean/train/1/18872_bimcv_pos.png. Skipping.


KeyboardInterrupt: 

  image_array = np.array(image, dtype=np.int32)


Failed to load image /ocean/projects/cis230079p/shared/CapStone-VeyTel-2024/datasets/data-clean/train/1/23119_bimcv_pos.png. Skipping.


Epoch 29:  20%|█▉        | 3840/19588 [02:47<03:41, 71.09it/s, acc=69.9479%, loss=0.6773, lr=0.000000, num_correct=2686]

In [5]:
train(model, train_loader, valid_loader, epochs=20)

Epoch 1: 100%|██████████| 19588/19588 [14:27<00:00, 22.57it/s, acc=60.8893%, loss=0.8100, lr=0.000010, num_correct=11927]


Validation Loss: 0.7639
Validation Accuracy: 63.81%
Checkpoint saved with validation loss 0.7639.


Epoch 2: 100%|██████████| 19588/19588 [14:08<00:00, 23.10it/s, acc=66.4591%, loss=0.7116, lr=0.000010, num_correct=13018]


Validation Loss: 0.7398
Validation Accuracy: 65.52%
Checkpoint saved with validation loss 0.7398.


Epoch 3: 100%|██████████| 19588/19588 [13:04<00:00, 24.96it/s, acc=68.3990%, loss=0.6792, lr=0.000010, num_correct=13398]


Validation Loss: 0.7235
Validation Accuracy: 65.60%
Checkpoint saved with validation loss 0.7235.


Epoch 4: 100%|██████████| 19588/19588 [14:54<00:00, 21.89it/s, acc=69.9816%, loss=0.6463, lr=0.000010, num_correct=13708]


Validation Loss: 0.6827
Validation Accuracy: 67.32%
Checkpoint saved with validation loss 0.6827.


Epoch 5: 100%|██████████| 19588/19588 [15:41<00:00, 20.81it/s, acc=70.3849%, loss=0.6329, lr=0.000010, num_correct=13787]


Validation Loss: 0.6866
Validation Accuracy: 67.97%


Epoch 6: 100%|██████████| 19588/19588 [13:52<00:00, 23.54it/s, acc=71.2528%, loss=0.6102, lr=0.000010, num_correct=13957]


Validation Loss: 0.6796
Validation Accuracy: 69.20%
Checkpoint saved with validation loss 0.6796.


Epoch 7: 100%|██████████| 19588/19588 [13:44<00:00, 23.76it/s, acc=71.9369%, loss=0.5988, lr=0.000010, num_correct=14091]


Validation Loss: 0.6882
Validation Accuracy: 69.85%


Epoch 8: 100%|██████████| 19588/19588 [14:29<00:00, 22.52it/s, acc=72.9784%, loss=0.5834, lr=0.000010, num_correct=14295] 


Validation Loss: 0.6814
Validation Accuracy: 67.73%
Early stopping triggered after 8 epochs.
Final model state saved to checkpoints/final_model_8.pth


In [4]:
train(model, train_loader, valid_loader, epochs=20)

Epoch 1: 100%|██████████| 19588/19588 [09:14<00:00, 35.33it/s, acc=51.4652%, loss=0.9116, lr=0.0010000, num_correct=10081]

Training loss: 0.9116





Validation Loss: 0.8618
Validation Accuracy: 57.03%
Epoch 1: Current Learning Rate: 0.001
Checkpoint saved at epoch 1 with validation loss 0.8618


Epoch 2: 100%|██████████| 19588/19588 [09:22<00:00, 34.83it/s, acc=53.4715%, loss=0.8753, lr=0.0010000, num_correct=10474]

Training loss: 0.8753





Validation Loss: 0.8825
Validation Accuracy: 44.04%
Epoch 2: Current Learning Rate: 0.001


Epoch 3: 100%|██████████| 19588/19588 [10:18<00:00, 31.65it/s, acc=54.4262%, loss=0.8603, lr=0.0010000, num_correct=10661]

Training loss: 0.8603





Validation Loss: 0.9135
Validation Accuracy: 42.97%
Epoch 3: Current Learning Rate: 0.0001


Epoch 4: 100%|██████████| 19588/19588 [09:26<00:00, 34.59it/s, acc=55.5391%, loss=0.8399, lr=0.0001000, num_correct=10879]

Training loss: 0.8399





Validation Loss: 0.8272
Validation Accuracy: 58.25%
Epoch 4: Current Learning Rate: 0.0001
Checkpoint saved at epoch 4 with validation loss 0.8272


Epoch 5: 100%|██████████| 19588/19588 [09:05<00:00, 35.92it/s, acc=56.4070%, loss=0.8282, lr=0.0001000, num_correct=11049]

Training loss: 0.8282





Validation Loss: 0.8262
Validation Accuracy: 57.76%
Epoch 5: Current Learning Rate: 0.0001
Checkpoint saved at epoch 5 with validation loss 0.8262


Epoch 6: 100%|██████████| 19588/19588 [19:43<00:00, 16.55it/s, acc=56.5244%, loss=0.8229, lr=0.0001000, num_correct=11072] 

Training loss: 0.8229





Validation Loss: 0.8310
Validation Accuracy: 58.82%
Epoch 6: Current Learning Rate: 0.0001


Epoch 7: 100%|██████████| 19588/19588 [10:04<00:00, 32.38it/s, acc=57.3259%, loss=0.8199, lr=0.0001000, num_correct=11229]

Training loss: 0.8199





Validation Loss: 0.8142
Validation Accuracy: 58.74%
Epoch 7: Current Learning Rate: 0.0001
Checkpoint saved at epoch 7 with validation loss 0.8142


Epoch 8: 100%|██████████| 19588/19588 [09:28<00:00, 34.47it/s, acc=57.4076%, loss=0.8182, lr=0.0001000, num_correct=11245]

Training loss: 0.8182





Validation Loss: 0.8138
Validation Accuracy: 58.17%
Epoch 8: Current Learning Rate: 0.0001
Checkpoint saved at epoch 8 with validation loss 0.8138


Epoch 9: 100%|██████████| 19588/19588 [11:30<00:00, 28.38it/s, acc=57.4127%, loss=0.8155, lr=0.0001000, num_correct=11246]

Training loss: 0.8155





Validation Loss: 0.8178
Validation Accuracy: 58.01%
Epoch 9: Current Learning Rate: 0.0001


Epoch 10: 100%|██████████| 19588/19588 [09:11<00:00, 35.54it/s, acc=57.4331%, loss=0.8135, lr=0.0001000, num_correct=11250]

Training loss: 0.8135





Validation Loss: 0.8112
Validation Accuracy: 58.42%
Epoch 10: Current Learning Rate: 0.0001
Checkpoint saved at epoch 10 with validation loss 0.8112


Epoch 11: 100%|██████████| 19588/19588 [10:34<00:00, 30.89it/s, acc=57.6935%, loss=0.8081, lr=0.0001000, num_correct=11301]

Training loss: 0.8081





Validation Loss: 0.8158
Validation Accuracy: 59.15%
Epoch 11: Current Learning Rate: 0.0001


Epoch 12: 100%|██████████| 19588/19588 [09:34<00:00, 34.08it/s, acc=57.9845%, loss=0.8059, lr=0.0001000, num_correct=11358]

Training loss: 0.8059





Validation Loss: 0.8138
Validation Accuracy: 58.42%
Epoch 12: Current Learning Rate: 1e-05


Epoch 13: 100%|██████████| 19588/19588 [09:43<00:00, 33.55it/s, acc=58.7247%, loss=0.7963, lr=0.0000100, num_correct=11503]

Training loss: 0.7963





Validation Loss: 0.8083
Validation Accuracy: 59.07%
Epoch 13: Current Learning Rate: 1e-05
Checkpoint saved at epoch 13 with validation loss 0.8083


Epoch 14: 100%|██████████| 19588/19588 [10:58<00:00, 29.76it/s, acc=58.7554%, loss=0.7934, lr=0.0000100, num_correct=11509]

Training loss: 0.7934





Validation Loss: 0.8078
Validation Accuracy: 59.97%
Epoch 14: Current Learning Rate: 1e-05
Checkpoint saved at epoch 14 with validation loss 0.8078


Epoch 15: 100%|██████████| 19588/19588 [10:35<00:00, 30.84it/s, acc=59.0412%, loss=0.7921, lr=0.0000100, num_correct=11565]

Training loss: 0.7921





Validation Loss: 0.8090
Validation Accuracy: 58.66%
Epoch 15: Current Learning Rate: 1e-05


Epoch 16: 100%|██████████| 19588/19588 [10:18<00:00, 31.66it/s, acc=59.2046%, loss=0.7913, lr=0.0000100, num_correct=11597]

Training loss: 0.7913





Validation Loss: 0.8125
Validation Accuracy: 59.40%
Epoch 16: Current Learning Rate: 1.0000000000000002e-06


Epoch 17: 100%|██████████| 19588/19588 [10:10<00:00, 32.06it/s, acc=59.1434%, loss=0.7906, lr=0.0000010, num_correct=11585]

Training loss: 0.7906





Validation Loss: 0.8096
Validation Accuracy: 59.48%
Epoch 17: Current Learning Rate: 1.0000000000000002e-06


Epoch 18: 100%|██████████| 19588/19588 [11:06<00:00, 29.38it/s, acc=59.0668%, loss=0.7899, lr=0.0000010, num_correct=11570]

Training loss: 0.7899





Validation Loss: 0.8097
Validation Accuracy: 59.64%
Epoch 18: Current Learning Rate: 1.0000000000000002e-07


Epoch 19: 100%|██████████| 19588/19588 [09:06<00:00, 35.84it/s, acc=59.0617%, loss=0.7903, lr=0.0000001, num_correct=11569]

Training loss: 0.7903





Validation Loss: 0.8096
Validation Accuracy: 59.64%
Epoch 19: Current Learning Rate: 1.0000000000000002e-07


Epoch 20: 100%|██████████| 19588/19588 [10:34<00:00, 30.88it/s, acc=59.0259%, loss=0.7898, lr=0.0000001, num_correct=11562]

Training loss: 0.7898





Validation Loss: 0.8097
Validation Accuracy: 59.64%
Epoch 20: Current Learning Rate: 1.0000000000000004e-08
Final model state saved to checkpoints/final_model_20.pth


In [None]:
train(model, train_loader, valid_loader, epochs=20)

Epoch 1: 100%|██████████| 19588/19588 [09:08<00:00, 35.72it/s, acc=51.5009%, loss=0.9166, lr=0.0010000, num_correct=10088]

Training loss: 0.9166





Validation Loss: 0.9312
Validation Accuracy: 47.39%
Epoch 1: Current Learning Rate: 0.001
Checkpoint saved at epoch 1 with validation loss 0.9312


Epoch 2: 100%|██████████| 19588/19588 [08:53<00:00, 36.70it/s, acc=53.1039%, loss=0.8793, lr=0.0010000, num_correct=10402]

Training loss: 0.8793





Validation Loss: 0.8574
Validation Accuracy: 57.52%
Epoch 2: Current Learning Rate: 0.001
Checkpoint saved at epoch 2 with validation loss 0.8574


Epoch 3: 100%|██████████| 19588/19588 [08:38<00:00, 37.77it/s, acc=53.3235%, loss=0.8870, lr=0.0010000, num_correct=10445]

Training loss: 0.8870





Validation Loss: 0.8467
Validation Accuracy: 57.11%
Epoch 3: Current Learning Rate: 0.001
Checkpoint saved at epoch 3 with validation loss 0.8467


Epoch 4: 100%|██████████| 19588/19588 [08:37<00:00, 37.89it/s, acc=54.2883%, loss=0.8701, lr=0.0010000, num_correct=10634]

Training loss: 0.8701



