In [1]:
import torch
import torch.nn as nn
import numpy as np
import mixmatch

In [2]:
class ExampleDataset(torch.utils.data.Dataset):
    """Map-style dataset over in-memory feature (and optional target) arrays.

    Each item is a dict holding a float32 ``'features'`` tensor and, when
    targets were supplied, a float32 ``'targets'`` tensor.

    Args:
        x: indexable sequence of per-sample features (e.g. a NumPy array whose
            first axis indexes samples).
        y: optional indexable sequence of per-sample targets aligned with ``x``.
            Leave as ``None`` to build an unlabeled dataset.

    Raises:
        ValueError: if ``y`` is given and its length differs from ``x``.
    """

    def __init__(self, x, y=None):
        # Fail fast on misaligned inputs instead of an IndexError (or silently
        # ignored trailing targets) surfacing later inside a DataLoader worker.
        if y is not None and len(y) != len(x):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``{'features': ...}`` plus ``'targets'`` when labels exist."""
        sample = {'features': torch.tensor(self.x[idx], dtype=torch.float32)}
        if self.y is not None:
            sample['targets'] = torch.tensor(self.y[idx], dtype=torch.float32)
        return sample

In [3]:
class MyModel(nn.Module):
    """Single-logit linear model over flattened inputs.

    Args:
        num_features: number of values per sample after flattening every
            dimension except the first (batch) one, e.g. ``10 * 10`` for
            ``(batch, 10, 10)`` inputs.
    """

    def __init__(self, num_features):
        super().__init__()
        self.classifier = nn.Linear(num_features, 1)

    def forward(self, x):
        """Return raw logits of shape ``(batch,)`` for a batch ``x``."""
        # reshape (unlike view) also accepts non-contiguous inputs, copying
        # only when necessary.
        flat = x.reshape(x.size(0), -1)
        # Squeeze only the trailing size-1 output dim: a bare .squeeze() would
        # also drop the batch dim when batch == 1, producing a 0-d tensor whose
        # shape no longer matches (1,)-shaped targets in the loss.
        return self.classifier(flat).squeeze(-1)

In [4]:
# Reproducibility: seed NumPy (synthetic data below) and torch (model init,
# DataLoader shuffling and worker seeds) so Restart & Run All gives the same outputs.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Toy problem: random labeled inputs (all targets 1) and constant unlabeled inputs.
NUM_SAMPLES = 1000
SAMPLE_SHAPE = (10, 10)
# NOTE(review): with the 'spawn' start method (default on Windows and macOS),
# worker processes may fail to unpickle classes defined inside a notebook;
# set this to 0 if iterating the loader fails there.
NUM_WORKERS = 4

# MixMatch hyperparameters. Keep K in sync with the K passed to
# get_mixmatch_loss (presumably used there to split the mixed batch).
MIXMATCH_K = 2         # presumably number of augmented copies per unlabeled batch
MIXMATCH_T = 0.5       # presumably sharpening temperature for guessed labels
MIXMATCH_ALPHA = 0.75  # presumably MixUp Beta(alpha, alpha) parameter

dataset_labeled = ExampleDataset(
    rng.random((NUM_SAMPLES, *SAMPLE_SHAPE)),
    np.ones(NUM_SAMPLES),
)
dataset_unlabeled = ExampleDataset(np.ones((NUM_SAMPLES, *SAMPLE_SHAPE)))
# Derive the flattened input size from the data shape instead of hard-coding 10 * 10.
model = MyModel(int(np.prod(SAMPLE_SHAPE)))
batch_size = 32
steps_per_epoch = 10
output_transform = torch.sigmoid

loader_labeled = torch.utils.data.DataLoader(
    dataset_labeled,
    batch_size=batch_size,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)

loader_mixmatch = mixmatch.MixmatchLoader(
    loader_labeled,
    dataset_unlabeled,
    model,
    output_transform,
    K=MIXMATCH_K,
    T=MIXMATCH_T,
    alpha=MIXMATCH_ALPHA,
)

In [5]:
# Combined MixMatch objective: BCE-with-logits for the labeled part and a
# weighted MSE term for the unlabeled part. K must match the K=2 given to
# MixmatchLoader so the loss presumably splits the mixed batch consistently.
criterion = mixmatch.get_mixmatch_loss(
    criterion_labeled=nn.BCEWithLogitsLoss(),
    criterion_unlabeled=nn.MSELoss(),
    weight_unlabeled=100.0,
    output_transform=output_transform,
    K=2,
)

In [6]:
# Draw a single mixed batch and display (features.shape, targets.shape); with
# batch_size=32 and K=2 the leading dim is presumably 32 labeled + 2 * 32 unlabeled.
batch = next(iter(loader_mixmatch))
tuple(batch[key].shape for key in ('features', 'targets'))

(torch.Size([96, 10, 10]), torch.Size([96]))

In [7]:
logits = model(batch['features'])
# The loss compares logits with targets element-wise: BCEWithLogitsLoss rejects
# mismatched shapes, and MSELoss would silently broadcast them, so check early.
assert logits.shape == batch['targets'].shape, (logits.shape, batch['targets'].shape)
logits.shape

torch.Size([96])

In [8]:
# Evaluate the combined MixMatch loss on the mixed batch; the result is a
# scalar tensor that still carries its autograd graph (see grad_fn below).
loss = criterion(
    logits,
    batch['targets'],
)
loss

tensor(1.3509, grad_fn=<AddBackward0>)