In [None]:
import lightning as L


class RPSDataModule(L.LightningDataModule):
    """Lightning data module that builds datasets and loaders per stage.

    Datasets are created lazily in :meth:`setup` from the directories given
    at construction time; the loader methods wrap whichever dataset the
    matching stage has created.

    Args:
        train_batch_size: Batch size used by the (shuffled) training loader.
        predict_batch_size: Batch size used by the validation, test and
            predict loaders, none of which shuffle.
        train_data_dir: Location handed to ``init_dataset`` for training.
        val_data_dir: Location handed to ``init_dataset`` for validation.
        test_data_dir: Location handed to ``init_dataset`` for testing.
        predict_data_dir: Location handed to ``init_predict_dataset``.
    """

    def __init__(
        self,
        train_batch_size,
        predict_batch_size,
        train_data_dir=None,
        val_data_dir=None,
        test_data_dir=None,
        predict_data_dir=None,
    ):
        super().__init__()

        # Batch sizes: one for training, one shared by every non-training loader.
        self.train_batch_size, self.predict_batch_size = (
            train_batch_size,
            predict_batch_size,
        )

        # Data locations, one per stage; any of them may be left as None.
        self.train_data_dir, self.val_data_dir = train_data_dir, val_data_dir
        self.test_data_dir, self.predict_data_dir = test_data_dir, predict_data_dir

    def setup(self, stage):
        """Create the dataset(s) required by ``stage``.

        ``"fit"`` builds the training and validation datasets, ``"validate"``
        only the validation one, ``"test"`` the test one and ``"predict"``
        the prediction one. Any other value leaves the module untouched.
        """
        if stage == "fit":
            self.train_dataset = init_dataset(self.train_data_dir)

        # Both fitting and stand-alone validation need the validation split.
        if stage in ("fit", "validate"):
            self.val_dataset = init_dataset(self.val_data_dir)
        elif stage == "test":
            self.test_dataset = init_dataset(self.test_data_dir)
        elif stage == "predict":
            self.predict_dataset = init_predict_dataset(self.predict_data_dir)

    def _ordered_loader(self, dataset):
        """Wrap ``dataset`` in a non-shuffling loader using ``predict_batch_size``."""
        return init_dataloader(dataset, self.predict_batch_size, shuffle=False)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return init_dataloader(self.train_dataset, self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the non-shuffled loader over the validation dataset."""
        return self._ordered_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the non-shuffled loader over the test dataset."""
        return self._ordered_loader(self.test_dataset)

    def predict_dataloader(self):
        """Return the non-shuffled loader over the prediction dataset."""
        return self._ordered_loader(self.predict_dataset)

In [None]:
class RPSModule(L.LightningModule):
    """Frozen feature extractor feeding a CatBoost classifier.

    The neural network is used purely to turn inputs into feature vectors;
    its parameters are frozen. Features and labels are buffered across the
    batches of an epoch, and at the end of every training epoch the CatBoost
    classifier is (re)fitted on the buffered training features. Validation
    accuracy is computed with the classifier as it stands when the validation
    epoch ends.

    Args:
        feature_extractor: Module called on each input batch to produce the
            features handed to CatBoost.
        num_classes: Stored on the instance; not otherwise used in this class.
        learning_rate: Learning rate of the placeholder Adam optimizer.
    """

    def __init__(self, feature_extractor, num_classes=3, learning_rate=1e-3):
        super().__init__()
        # NOTE(review): this records every __init__ argument, including the
        # ``feature_extractor`` module itself. Lightning suggests
        # ``ignore=["feature_extractor"]`` for nn.Module arguments, but that
        # changes how ``load_from_checkpoint`` must be called - confirm first.
        self.save_hyperparameters()

        self.feature_extractor = feature_extractor
        self.classifier = CatBoostClassifier(
            task_type="GPU",
            iterations=2000,
            random_state=42,
            silent=True,
        )

        self.learning_rate = learning_rate
        self.num_classes = num_classes

        # Freeze feature extractor
        # NOTE(review): this only disables gradients. Layers whose behavior
        # depends on train/eval mode are still switched by Lightning - confirm
        # whether the extractor should additionally be kept in eval mode.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Per-epoch buffers of features and labels (one entry per sample).
        self.train_features = []
        self.train_labels = []
        self.val_features = []
        self.val_labels = []

    def forward(self, x):
        """Return the feature extractor's output for ``x``."""
        return self.feature_extractor(x)

    def training_step(self, batch, batch_idx):
        """Buffer this batch's features and labels; return a dummy loss.

        No gradient-based learning happens here: the real training is the
        CatBoost fit in :meth:`on_train_epoch_end`.
        """
        x, y = batch
        features = self(x)

        # ``detach`` keeps the NumPy conversion valid even if the extractor
        # output is attached to an autograd graph.
        self.train_features.extend(features.detach().cpu().numpy())
        self.train_labels.extend(y.detach().cpu().numpy())

        # Dummy loss so that Lightning's automatic optimization has something
        # to call ``backward`` on.
        return {"loss": torch.tensor(0.0, requires_grad=True)}

    def on_train_epoch_end(self):
        """Fit CatBoost on the features buffered during the epoch."""
        if self.train_features and self.train_labels:
            self.classifier.fit(np.array(self.train_features), np.array(self.train_labels))

            # Clear stored features and labels
            self.train_features.clear()
            self.train_labels.clear()

    def validation_step(self, batch, batch_idx):
        """Buffer this batch's features and labels and return the features."""
        x, y = batch
        features = self(x)

        self.val_features.extend(features.detach().cpu().numpy())
        self.val_labels.extend(y.detach().cpu().numpy())

        return features

    def on_validation_epoch_end(self):
        """Log ``val_acc`` for the buffered validation samples.

        Scoring is skipped while the classifier has not been fitted yet.
        Lightning runs validation during its pre-training sanity check and
        inside the first training epoch, both before
        :meth:`on_train_epoch_end` has had a chance to fit CatBoost; calling
        ``predict`` then would raise. The buffers are emptied either way so
        that those samples do not leak into the next validation epoch.
        """
        if not (self.val_features and self.val_labels):
            return

        features = np.array(self.val_features)
        labels = np.array(self.val_labels)

        # Drain the buffers up front so they are reset on every exit path.
        self.val_features.clear()
        self.val_labels.clear()

        if not self.classifier.is_fitted():
            return

        val_pred = self.classifier.predict(features)
        val_acc = accuracy_score(labels, val_pred)

        self.log("val_acc", val_acc, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``(class probabilities, y)`` for one prediction batch.

        ``y`` is the second element of the batch, passed through unchanged.
        """
        x, y = batch
        features = self(x)
        predictions = self.classifier.predict_proba(features.detach().cpu().numpy())
        return predictions, y

    def configure_optimizers(self):
        """Return a placeholder Adam optimizer.

        The classifier is trained by CatBoost, not by gradient descent; the
        optimizer is returned only because the training loop expects one.
        """
        return optim.Adam(self.parameters(), lr=self.learning_rate)