In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Load the pre-trained encoder object from disk.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# map_location='cpu' makes the load independent of the device the checkpoint was
# saved from; the model is moved to the training device later by `train`.
encoder = torch.load('mickeyencoder.pth', map_location='cpu')

In [3]:
class DOX(nn.Module):
    """Linear classification head on top of a frozen encoder.

    The encoder is always evaluated under ``torch.no_grad()`` and kept in
    evaluation mode, so only ``self.classifier`` is trained.

    Args:
        encoder: Callable module mapping an input batch to features whose last
            dimension is ``embed_dim``.
        num_classes: Number of output logits.
        embed_dim: Size of the encoder's feature dimension (default 1280, the
            value that was previously hard-coded).
    """

    def __init__(self, encoder, num_classes, embed_dim=1280):
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(embed_dim, num_classes)

    def train(self, mode=True):
        """Set the training mode of the head while keeping the encoder in eval mode.

        The encoder never receives gradients, so layers such as dropout or
        batch-norm inside it must not behave stochastically or update running
        statistics when the surrounding model is put into training mode.
        """
        super().train(mode)
        if isinstance(self.encoder, nn.Module):
            self.encoder.eval()
        return self

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        Encoders that emit a token sequence ``(batch, tokens, embed_dim)`` are
        mean-pooled over the token axis first; without this the head would
        produce ``(batch, tokens, num_classes)`` and a classification loss would
        silently treat the token axis as the class axis.
        """
        with torch.no_grad():
            features = self.encoder(x)
        if features.dim() == 3:
            features = features.mean(dim=1)
        return self.classifier(features)

In [4]:
# Freeze the encoder: none of its parameters should receive gradients.
for frozen_param in encoder.parameters():
    frozen_param.requires_grad_(False)

In [5]:
import random
import os
import re
import msgpack
from io import BytesIO
from PIL import Image
import torch
from typing import Union, List, Optional
from pathlib import Path

class MsgPackIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that streams image records from msgpack shard files.

    Shards are files named ``shard_<number>.msg`` inside one or more
    directories. The list of shards is shuffled with a fixed seed and cut at
    ``split_ratio``; ``split="train"`` keeps the first part, any other value
    keeps the remainder. Because the seed is fixed (and the shard list is
    sorted before shuffling), two instances built with the same ``path``,
    ``seed`` and ``split_ratio`` produce complementary, non-overlapping splits.

    Each yielded item is ``(image, image_id)`` where ``image`` is an RGB PIL
    image, or the result of ``transformation(image)`` if one was given.

    Args:
        path: A shard directory or a list/tuple/set of shard directories.
        key_img_id: Record key holding the image id (records are unpacked with
            ``raw=True``, so keys are compared as UTF-8 bytes).
        key_img_encoded: Record key holding the encoded image bytes.
        transformation: Optional callable applied to each PIL image.
        shuffle: Shuffle the shard order and the samples inside each cache
            buffer. The order is derived from ``seed`` and is therefore the
            same on every pass over the dataset.
        split: ``"train"`` for the first ``split_ratio`` share of the shards,
            anything else for the rest.
        split_ratio: Fraction of shards assigned to the training split.
        cache_size: Number of records buffered before they are (optionally
            shuffled and) yielded. Values below 1 are treated as 1.
        seed: Seed for the shard split and for shuffling. Previously every
            instance drew its own random seed, which made the train and test
            splits of two instances overlap.

    Raises:
        ValueError: If no shard file is found in any of the directories.
    """

    def __init__(
        self,
        path: Union[str, List[str]],
        key_img_id: str = "id",
        key_img_encoded: str = "image",
        transformation=None,
        shuffle=False,
        split: str = "train",
        split_ratio: float = 0.8,
        cache_size=6 * 4096,
        seed: int = 0,
    ):
        super().__init__()
        self.cache_size = cache_size
        self.transformation = transformation
        self.shuffle = shuffle
        self.split = split
        self.split_ratio = split_ratio
        self.seed = seed
        self.key_img_id = key_img_id.encode("utf-8")
        self.key_img_encoded = key_img_encoded.encode("utf-8")

        # Normalise to a list of directories; a tuple counts as a collection too.
        if isinstance(path, (list, tuple, set)):
            self.path = list(path)
        else:
            self.path = [path]

        self.shards = self.__init_shards(self.path)
        # NOTE: despite its name this holds the shard descriptors of the split.
        self.shard_indices = self._split_shards()

    @staticmethod
    def __init_shards(path: List[Union[str, Path]]) -> list:
        """Collect one descriptor dict per ``shard_<number>.msg`` file.

        The whole file name must match (so ``shard_1.msgpack`` is ignored) and
        the real file name is used for ``shard_path``, which keeps zero-padded
        names such as ``shard_007.msg`` working.
        """
        shard_name = re.compile(r"shard_(\d+)\.msg")
        shards = []
        for i, p in enumerate(path):
            for file_name in os.listdir(p):
                found = shard_name.fullmatch(file_name)
                if found is None:
                    continue
                shards.append(
                    {
                        "path_index": i,
                        "path": p,
                        "shard_index": int(found.group(1)),
                        "shard_path": os.path.join(p, file_name),
                    }
                )
        if len(shards) == 0:
            raise ValueError("No shards found")

        # os.listdir order is arbitrary; sort so the seeded shuffle is reproducible.
        shards.sort(key=lambda s: (s["path_index"], s["shard_index"], str(s["shard_path"])))
        return shards

    def _split_shards(self):
        """Shuffle the shard list with ``self.seed`` and return this split's part."""
        # A private RNG avoids reseeding the process-wide `random` module.
        random.Random(self.seed).shuffle(self.shards)
        split_point = int(len(self.shards) * self.split_ratio)
        if self.split == "train":
            return self.shards[:split_point]
        return self.shards[split_point:]

    def _process_sample(self, x):
        """Decode one raw record into ``(image, image_id)``."""
        img = Image.open(BytesIO(x[self.key_img_encoded]))
        if img.mode != "RGB":
            img = img.convert("RGB")

        if self.transformation:
            img = self.transformation(img)

        _id = x[self.key_img_id].decode("utf-8")
        return img, _id

    def _drain(self, cache, rng):
        """Yield and remove every buffered record, shuffling first if enabled."""
        if self.shuffle:
            rng.shuffle(cache)
        while cache:
            yield self._process_sample(cache.pop())

    def __iter__(self):
        shard_order = list(range(len(self.shard_indices)))
        rng = random.Random(self.seed)

        # The shard order must be identical in every worker so that the slices
        # taken below are disjoint.
        if self.shuffle:
            rng.shuffle(shard_order)

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Give each DataLoader worker a contiguous, non-overlapping slice.
            total = len(shard_order)
            begin = worker_info.id * total // worker_info.num_workers
            end = (worker_info.id + 1) * total // worker_info.num_workers
            shard_order = shard_order[begin:end]

        # A threshold below 1 would never buffer anything and drop all samples.
        flush_at = max(1, self.cache_size)
        cache = []

        for shard_index in shard_order:
            shard = self.shard_indices[shard_index]

            with open(shard['shard_path'], 'rb') as f:
                unpacker = msgpack.Unpacker(f, max_buffer_size=1024 * 1024 * 1024, raw=True)
                for x in unpacker:
                    if x is None:
                        continue

                    cache.append(x)
                    if len(cache) >= flush_at:
                        yield from self._drain(cache, rng)

        # Emit whatever is left after the last shard.
        yield from self._drain(cache, rng)



In [6]:
from src.transforms import make_transforms

# Settings forwarded to make_transforms. Blur, horizontal flip and colour
# distortion are switched off and the jitter strength is 0.0.
crop_size = 224
crop_scale = [0.8, 1.0]
use_gaussian_blur = False
use_horizontal_flip = False
use_color_distortion = False
color_jitter = 0.0

transform_kwargs = dict(
    crop_size=crop_size,
    crop_scale=crop_scale,
    gaussian_blur=use_gaussian_blur,
    horizontal_flip=use_horizontal_flip,
    color_distortion=use_color_distortion,
    color_jitter=color_jitter,
)
transform = make_transforms(**transform_kwargs)

In [7]:
import json
# Load the id -> cell lookup that get_labels indexes by image/location id.
with open('output_square_ids.json', mode='r') as cell_file:
    img_cell = json.load(cell_file)

In [8]:
def get_labels(loc_ids, num_classes=419, device="cuda:0", cell_lookup=None):
    """Build a one-hot label matrix for a batch of location/image ids.

    Args:
        loc_ids: Iterable of ids; each is looked up in ``cell_lookup``.
        num_classes: Width of each one-hot row (default 419, as before).
        device: Device the result is moved to (default ``"cuda:0"``, as before).
        cell_lookup: Mapping from id to the index (or indices) to set to 1.
            Defaults to the module-level ``img_cell`` mapping.

    Returns:
        A ``torch.long`` tensor of shape ``(len(loc_ids), num_classes)`` on
        ``device``, or ``None`` when the batch is empty, an id is missing from
        the lookup, or a looked-up index is not valid for ``num_classes``.

    Only lookup/indexing problems are turned into ``None``; other failures
    (for example an unavailable device) now propagate instead of being
    silently swallowed by a bare ``except``.
    """
    lookup = img_cell if cell_lookup is None else cell_lookup
    loc_ids = list(loc_ids)
    if not loc_ids:
        return None

    # Fill the whole batch on the CPU and transfer it once, rather than moving
    # one row at a time to the device.
    labels = torch.zeros((len(loc_ids), num_classes), dtype=torch.long)
    try:
        for row, loc in enumerate(loc_ids):
            labels[row][lookup[loc]] = 1
    except (KeyError, IndexError, TypeError, ValueError):
        return None
    return labels.to(device)

In [9]:
# Classifier over the loaded encoder with 419 output classes.
model = DOX(encoder=encoder, num_classes=419)

In [10]:
# Loss function and optimiser used by the training run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [11]:
# Directory containing the shard files. The MP16_ROOT environment variable
# overrides the default so the notebook also runs on machines where the data
# lives elsewhere.
root = os.environ.get("MP16_ROOT", '/nobackup/users/nolangc/mp16')

In [12]:
# The train and test datasets must shuffle the shard list identically, otherwise
# the two splits overlap and test data leaks into training. Seeding the global
# RNG with the same value right before each construction guarantees that any
# seed the dataset draws from `random` in its constructor is the same for both.
SPLIT_SEED = 0

random.seed(SPLIT_SEED)
train_dataset = MsgPackIterableDataset(path=root, transformation=transform, split="train", split_ratio=0.8)
random.seed(SPLIT_SEED)
test_dataset = MsgPackIterableDataset(path=root, transformation=transform, split="test", split_ratio=0.8)

# Guard against leakage: no shard file may belong to both splits.
_train_shard_paths = {shard["shard_path"] for shard in train_dataset.shard_indices}
_test_shard_paths = {shard["shard_path"] for shard in test_dataset.shard_indices}
assert _train_shard_paths.isdisjoint(_test_shard_paths), "train/test shard splits overlap"

In [13]:
# Batches of 16 samples, loaded by 4 worker processes; the final partial
# batch is kept.
loader_options = dict(
    batch_size=16,
    num_workers=4,
    drop_last=False,
    pin_memory=False,
    persistent_workers=False,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

In [14]:
# Ask PyTorch to release cached, currently unused CUDA memory before the
# training run starts.
torch.cuda.empty_cache()

In [23]:
from tqdm import tqdm

def train(model, data_loader, criterion, optimizer, num_epochs, device, log_file="training_log.txt"):
    """Train ``model`` and log per-batch and per-epoch losses to ``log_file``.

    Args:
        model: Module to train; it is moved to ``device`` and put in train mode.
        data_loader: Iterable yielding ``(image_data, loc_ids)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepping the model's parameters.
        num_epochs: Number of passes over ``data_loader``.
        device: Device that images and labels are moved to.
        log_file: Path of the text log; it is overwritten on every call.

    Returns:
        List with the average loss of each epoch (0 for an epoch in which no
        batch was processed successfully).

    A batch whose labels cannot be built is skipped, and a batch that raises is
    logged and skipped, so one bad batch does not end the run. Because such
    problems are otherwise only visible in the log file, the number of skipped
    and failed batches is also printed after each epoch.
    """
    model.train()  # Set the model to training mode
    model.to(device)  # Ensure the model is on the correct device
    epoch_losses = []

    with open(log_file, mode='w') as file:
        file.write("Epoch, Batch, Batch Loss, Average Loss\n")
        file.flush()

        for epoch in range(num_epochs):
            total_loss = 0
            processed_batches = 0
            skipped_batches = 0
            failed_batches = 0

            for batch_idx, (image_data, loc_ids) in enumerate(tqdm(data_loader)):
                try:
                    labels = get_labels(loc_ids)
                    if labels is None:
                        skipped_batches += 1
                        file.write(f"Skipping batch {batch_idx + 1} due to no labels.\n")
                        file.flush()
                        continue
                    image_data = image_data.to(device)
                    labels = labels.to(device).long()

                    optimizer.zero_grad()
                    outputs = model(image_data)

                    # CrossEntropyLoss needs class indices when the target is an
                    # integer tensor. A one-hot integer target with the same
                    # shape as the logits is rejected as an invalid
                    # class-probability target, so collapse it to indices.
                    if (
                        isinstance(criterion, nn.CrossEntropyLoss)
                        and labels.dim() > 1
                        and labels.shape == outputs.shape
                    ):
                        labels = labels.argmax(dim=1)

                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    batch_loss = loss.item()
                    total_loss += batch_loss
                    processed_batches += 1

                    # Log the batch loss and the running average, flushing so
                    # the log can be followed while training is in progress.
                    running_avg = total_loss / processed_batches
                    file.write(f"{epoch + 1}, {batch_idx + 1}, {batch_loss}, {running_avg}\n")
                    file.flush()

                except Exception as e:
                    failed_batches += 1
                    file.write(f"Error at epoch {epoch + 1}, batch {batch_idx + 1}: {e}\n")
                    file.flush()
                    continue

            if processed_batches > 0:
                avg_loss = total_loss / processed_batches
            else:
                avg_loss = 0
            epoch_losses.append(avg_loss)

            # Log the average loss at the end of the epoch to the text file
            file.write(f"{epoch + 1}, Epoch Summary, , {avg_loss}\n")
            file.flush()
            print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss}")
            if skipped_batches or failed_batches:
                print(
                    f"Epoch {epoch + 1}: {skipped_batches} batch(es) skipped for missing labels, "
                    f"{failed_batches} failed with errors (see {log_file})."
                )

    return epoch_losses



In [None]:
# Run a single training epoch on the first CUDA device.
train(
    model=model,
    data_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=1,
    device="cuda:0",
)

6243it [44:54,  2.18it/s]