In [11]:
import torch
import torch.nn as nn
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class MyDataset(Dataset):
    """Image-classification dataset backed by a JSON annotation file.

    The annotation file holds a JSON object mapping image file names
    (relative to ``data_dir``) to 1-indexed integer class labels. Each sample
    is returned as ``(image_tensor, zero_indexed_label)``.

    Args:
        annotation_file: Path to the JSON annotation file.
        data_dir: Directory containing the image files named in the annotations.
    """

    def __init__(self, annotation_file, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.annotations = self._load_annotations(annotation_file)
        # Fixed sample order: index ``idx`` always refers to the same
        # annotation key for both the image file and its label.
        self.image_names = list(self.annotations.keys())

        # NOTE(review): RandomCrop is a random augmentation, yet it is also
        # applied to any validation/test split built with this class — confirm.
        self.transform = transforms.Compose([
            transforms.RandomCrop((256, 256), pad_if_needed=True),
            # Commonly used ImageNet mean/std; expects a float tensor in [0, 1].
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_annotations(self, annotation_file):
        """Return the parsed JSON content of ``annotation_file``."""
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(annotation_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _sample_info(self, idx):
        """Return ``(image_path, zero_indexed_label)`` for sample ``idx``.

        Both values are derived from the same annotation key, so an image can
        never be paired with another image's label.
        """
        img_name = self.image_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        label = self.annotations[img_name] - 1  # Convert 1-indexed to 0-indexed
        return img_path, label

    def __getitem__(self, idx):
        """Load image ``idx``, scale it to [0, 1], transform it, and return ``(img, label)``."""
        img_path, label = self._sample_info(idx)
        # Force 3-channel RGB so grayscale/RGBA files match the rest of the batch.
        img = torchvision.io.read_image(img_path, ImageReadMode.RGB)
        img = torch.div(img, 255.0)  # read_image yields uint8 in [0, 255]
        return self.transform(img), label

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.image_names)
    
# Build each split's dataset and loader exactly once.
# Relative paths resolve against the current working directory (the notebook's
# cwd), so the annotation files and image folders must live there.
train_dataset = MyDataset(annotation_file="train_annos.json", data_dir="cars_train/")
training_dataset = train_dataset  # alias kept for code that refers to this name
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

test_dataset = MyDataset(annotation_file="test_annos.json", data_dir="cars_test/")
# Evaluation order does not change metrics, so keep it deterministic.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
# Plain-PyTorch training loop.
# NOTE(review): `model`, `optimizer` and `loss_func` are not defined before this
# point in the visible code — presumably they come from an earlier notebook cell.
num_epochs = 5
losses = []  # mean training loss of each epoch

model.train()  # make sure training-mode behaviour (dropout, batch-norm updates) is on
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = len(train_loader)
    with tqdm(total=num_batches, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
        for batch_idx, (images, labels) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            # Show the running mean so progress is visible during the epoch.
            pbar.set_postfix({'Loss': epoch_loss / batch_idx})
            pbar.update(1)

        # max(..., 1) avoids ZeroDivisionError when the loader is empty.
        losses.append(epoch_loss / max(num_batches, 1))
        pbar.set_postfix({'Loss': losses[-1]})

    # Overwrite the weights file after every epoch so progress survives an interruption.
    torch.save(model.state_dict(), "model_weights.pth")
    

# MyDataset's `data_dir` has no default, so every split must be given its image
# directory; omitting it raises TypeError.
# NOTE(review): "data_dir" looks like a placeholder, and all three split files
# are assumed to index the same image folder — confirm. MyDataset parses these
# files with json.load, so they must contain JSON despite the .txt extension.
DATA_DIR = "data_dir"

train_dataset = MyDataset("training.txt", DATA_DIR)
validation_dataset = MyDataset("validation.txt", DATA_DIR)
test_dataset = MyDataset("test.txt", DATA_DIR)

# Shuffle only the training split so the optimizer doesn't see a fixed
# (possibly class-sorted) order; evaluation order does not affect metrics.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)



class LitNetwork(pl.LightningModule):
    """LightningModule skeleton for multiclass image classification.

    The network itself is still a TODO: ``forward`` currently returns its
    input unchanged.

    Args:
        n_classes: Number of classification categories in your dataset; sizes
            the accuracy metrics. Defaults to 10.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        # TODO: setup network here

        self.n_classes = n_classes
        self.loss_func = torch.nn.CrossEntropyLoss()
        # Separate metric objects so validation and test state never mix.
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=n_classes, average='micro')
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=n_classes, average='micro')

    def forward(self, x):
        # TODO: perform the forward pass, which happens when someone calls network(x)
        return x

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=1e-3) over all module parameters."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, data, batch_idx):
        """Compute and log the cross-entropy loss for one ``(images, labels)`` batch."""
        im, label = data[0], data[1]
        outs = self(im)
        loss = self.loss_func(outs, label)
        # Log with the real batch size so epoch-level averaging is weighted
        # correctly when the loader uses batches larger than 1.
        self.log("train_loss", loss, batch_size=len(label), sync_dist=True)
        return loss

    def validation_step(self, val_data, batch_idx):
        """Update and log validation accuracy for one ``(images, labels)`` batch."""
        im, label = val_data[0], val_data[1]
        outs = self(im)
        self.val_acc(outs, label)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        return None

    def test_step(self, test_data, batch_idx):
        """Update and log test accuracy for one ``(images, labels)`` batch."""
        # The batch is a sequence, not an object with .im/.label attributes.
        im, label = test_data[0], test_data[1]
        outs = self(im)
        self.test_acc(outs, label)
        self.log("test_acc", self.test_acc, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        return None


model = LitNetwork()
# Keep only the checkpoint with the highest validation accuracy.
checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')
logger = pl_loggers.TensorBoardLogger(save_dir="my_logs")
# Alternative: plain CSV logging.
#logger = pl_loggers.CSVLogger(save_dir="my_logs",name="my_csv_logs")

# Pick the best available accelerator instead of hard-coding "mps", which
# cannot be used on machines without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    device = "mps"  # Apple M1/M2
elif torch.cuda.is_available():
    device = "gpu"  # NVIDIA GPU
else:
    device = "cpu"

trainer = pl.Trainer(max_epochs=10, accelerator=device, callbacks=[checkpoint], logger=logger)
trainer.fit(model, train_loader, val_loader)

# Evaluate the best checkpoint (by val_acc) on the test split.
trainer.test(ckpt_path="best", dataloaders=test_loader)

FileNotFoundError: [Errno 2] No such file or directory: 'train_annos.json'