In [None]:
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import random
import copy

In [None]:
# Enumerate the CUDA devices visible to this process; both lists stay empty
# when CUDA is unavailable.
gpu_ids = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
device_names = [torch.cuda.get_device_name(idx) for idx in gpu_ids]
print(gpu_ids)
print(device_names)

# With several GPUs, pin the run to one card explicitly (change the index
# below to choose a different GPU); otherwise use the default CUDA device,
# or the CPU when CUDA is unavailable.
if len(gpu_ids) > 1:
    device = f'cuda:{gpu_ids[0]}'
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

In [None]:
num_samples = 500
# num_samples = (1000,1000)

# Fixed seed so that a fresh "Restart & Run All" regenerates exactly the same
# two-moons sample (make_moons draws both the shuffle and the noise from it).
SEED = 0

points, labels = make_moons(n_samples=num_samples, shuffle=True, noise=0.1, random_state=SEED)

In [None]:
# Hand-picked labeled anchors: three (x, y) points for class 0 followed by
# three for class 1.
points_l = np.array([
    [-0.9, 0.25],
    [0.4, 1.05],
    [0.9, 0.05],
    # [0.2, 0.1],  # left end of class 2 (disabled)
    [0.7, -0.4],
    [0.85, -0.35],
    [1.9, 0.2],
])
labels_l = np.array([0] * 3 + [1] * 3)

In [None]:
# Unlabeled moons as small markers, with the labeled anchors drawn larger on top.
# Each array is (N, 2), so its transpose unpacks into the x and y columns.
plt.scatter(*points.T, s=10)
plt.scatter(*points_l.T, s=30)

In [None]:
# Optional translation of the whole dataset by a random non-zero integer in
# [-10, 10]. It is switched off: the effective offset is 0, so the two in-place
# additions below leave the points unchanged. Flip the flag to enable it.
USE_RANDOM_OFFSET = False

offset = 0
if USE_RANDOM_OFFSET:
    # Redraw until the offset is non-zero.
    while offset == 0:
        offset = random.randint(-10, 10)

# NOTE: in-place shift -- with a non-zero offset, re-running this cell shifts
# the data again.
points += offset
points_l += offset

In [None]:
# Same preview as before, drawn again after the offset has been applied.
# Each array is (N, 2), so its transpose unpacks into the x and y columns.
plt.scatter(*points.T, s=10)
plt.scatter(*points_l.T, s=30)

In [None]:
# gridshape = (3,4)
# marker_size = 0.1
# num = 12

# jr = num//4
# for i in range(4):
#     for j in range(jr):
#         loc = (j,i)
#         ax = plt.subplot2grid(gridshape, loc)
#         point_noise = points + (torch.rand(size=points.shape).numpy()-0.5) * (0.1*((i+1)+(jr+1)*j))
#         plt.scatter(point_noise[:,0],point_noise[:,1], s=10)
#         plt.scatter(points[:,0],points[:,1], s=10)
#         plt.title(((i+1)+(jr+1)*j))

# ax.figure.set(figwidth=28, figheight=6 * jr)
# plt.show()

In [None]:
class TMClassifier(nn.Module):
    """Small fully connected classifier: 2 inputs -> 8 -> 8 -> 2 output logits.

    Sigmoid activations follow each of the first two linear layers; the last
    linear layer's output is returned without an activation.
    """

    def __init__(self) -> None:
        super().__init__()

        hidden = 8
        # Sigmoid is the active non-linearity; nn.ReLU can be swapped in at
        # the same positions to experiment.
        self.layers = nn.Sequential(
            nn.Linear(2, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, 2),
        )

        # init_weight() is intentionally not called here, so the default
        # nn.Linear initialisation is kept.

    def init_weight(self):
        """Re-draw every linear layer's weight from U(-0.1, 0.1); biases are untouched."""
        linear_layers = [module for module in self.layers if isinstance(module, nn.Linear)]
        for linear in linear_layers:
            nn.init.uniform_(linear.weight, -0.1, 0.1)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return its output."""
        return self.layers(x)

In [None]:
class MyDataset(Dataset):
    """In-memory dataset pairing float32 inputs with int64 labels.

    Args:
        x_data: Inputs, converted to a ``torch.float32`` tensor.
        y_data: Labels, converted to a ``torch.long`` tensor.
        transform: Optional callable applied to each ``(x, y)`` pair on access.
    """

    def __init__(self, x_data, y_data, transform=None):
        self.x_data = torch.tensor(x_data, dtype=torch.float32)
        self.y_data = torch.tensor(y_data, dtype=torch.long)
        self.transform = transform
        self.len = len(x_data)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair at ``index``, passed through ``transform`` when one is set."""
        pair = (self.x_data[index], self.y_data[index])
        # Preprocess only when a (truthy) transform was supplied.
        return self.transform(pair) if self.transform else pair

    def __len__(self):
        """Number of samples, taken from ``len(x_data)`` at construction time."""
        return self.len

In [None]:
# Wrap the labeled anchors and the moons sample as datasets.
labeledData = MyDataset(x_data=points_l, y_data=labels_l)
unlabeledData = MyDataset(x_data=points, y_data=labels)

# Both loaders reshuffle on every pass. The batch sizes (6 and 500) equal the
# number of labeled anchors and num_samples, so each pass yields a single batch.
labeledLoader = DataLoader(dataset=labeledData, shuffle=True, batch_size=6)
unlabeledLoader = DataLoader(dataset=unlabeledData, shuffle=True, batch_size=500)

In [None]:
# Regular sampling grid used to render decision regions: 800 x-steps by
# 500 y-steps, shifted by the same offset as the data.
x = np.linspace(start=offset - 1.5, stop=offset + 2.5, num=800)
y = np.linspace(start=offset - 1.0, stop=offset + 1.5, num=500)

X, Y = np.meshgrid(x, y)

# Marker size for the dense grid scatter plots.
marker_size = 0.1
# plt.scatter(X, Y)

In [None]:
# One row per grid node: column 0 holds the x coordinate, column 1 the y coordinate.
arr = np.stack([X.flatten(), Y.flatten()], axis=-1)

In [None]:
# Grid nodes wrapped as a dataset; the labels are zero-filled placeholders
# with the same shape as the grid array.
renderData = MyDataset(x_data=arr, y_data=np.zeros_like(arr))

# Fixed order and no dropped tail, so batches cover the grid rows in sequence.
renderLoader = DataLoader(dataset=renderData, batch_size=1024, drop_last=False, shuffle=False)

In [None]:
def meta_pseudo_labels_step(teacher_model,
                            teacher_optimizer,
                            student_model,
                            student_optimizer,
                            uo_data,
                            ua_data,
                            l_data,
                            l_label):
    """Run one Meta Pseudo Labels update: first the student, then the teacher.

    The student takes one optimizer step on the teacher's hard pseudo labels.
    The teacher then takes one step on the sum of its supervised loss, a
    confidence-masked consistency (UDA) loss and the MPL feedback term, which
    is weighted by how much the student's labeled loss changed.

    Args:
        teacher_model: Module returning class logits; updated in place.
        teacher_optimizer: Optimizer over the teacher's parameters.
        student_model: Module returning class logits; updated in place.
        student_optimizer: Optimizer over the student's parameters.
        uo_data: Batch of unlabeled inputs (original).
        ua_data: The same unlabeled batch after augmentation.
        l_data: Batch of labeled inputs.
        l_label: Integer class targets for ``l_data``.

    Returns:
        Tuple ``(teacher_loss, student_loss, loss_t_mpl)`` as NumPy values
        detached from the autograd graph.
    """
    # CrossEntropyLoss applies log-softmax internally, so it must be fed raw
    # logits; passing softmax output would squash the loss and its gradients.
    loss_fn = nn.CrossEntropyLoss()
    threshold = 0.6     # minimum teacher confidence for a sample to count in the UDA loss
    temperature = 0.75  # < 1 sharpens the teacher's soft pseudo labels

    #### teacher UDA loss
    output_t_uo = teacher_model(uo_data)
    soft_pseudo_label = torch.softmax(output_t_uo / temperature, dim=-1).detach()
    hard_pseudo_label = torch.argmax(soft_pseudo_label, dim=-1)

    output_t_ua = teacher_model(ua_data)
    max_probs, _ = torch.max(soft_pseudo_label, dim=-1, keepdim=True)
    mask = torch.greater_equal(max_probs, threshold).type(torch.float32)
    per_class_ce = -soft_pseudo_label * torch.log_softmax(output_t_ua, dim=-1)
    loss_t_u = torch.mean(per_class_ce * mask)  # consistency loss

    #### teacher supervised loss
    loss_t_supervised = loss_fn(teacher_model(l_data), l_label)

    #### student performance old
    with torch.no_grad():  # only the value is needed, not a graph
        loss_l_old = loss_fn(student_model(l_data), l_label)

    #### student loss (trained on augmented data with the teacher's hard labels)
    student_loss = loss_fn(student_model(ua_data), hard_pseudo_label)

    #### student update
    student_optimizer.zero_grad()
    student_loss.backward()
    student_optimizer.step()

    #### student performance new
    with torch.no_grad():
        loss_l_new = loss_fn(student_model(l_data), l_label)

    #### compute MPL loss
    # Feedback signal: change of the student's labeled loss caused by its update.
    dot_product = loss_l_new - loss_l_old
    # Cross-entropy of the teacher's augmented-batch logits against its own
    # hard prediction. The target is a class index, not a multiplicative weight.
    teacher_hard_label = torch.argmax(output_t_ua.detach(), dim=-1)
    loss_t_mpl = dot_product * loss_fn(output_t_ua, teacher_hard_label)

    #### teacher update
    teacher_optimizer.zero_grad()
    teacher_loss = loss_t_supervised + loss_t_u + loss_t_mpl
    teacher_loss.backward()
    teacher_optimizer.step()

    # .cpu() keeps the conversion valid when the models live on an accelerator.
    return (teacher_loss.detach().cpu().numpy(),
            student_loss.detach().cpu().numpy(),
            loss_t_mpl.detach().cpu().numpy())

In [None]:
def _next_batch(iterator, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once ``iterator`` is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def _accuracy(model, inputs, targets, total):
    """Fraction of ``inputs`` whose argmax prediction equals ``targets`` (NumPy array)."""
    with torch.no_grad():
        predicted = torch.argmax(model(inputs), dim=-1).cpu().numpy()
    return (predicted == targets).sum() / total


def _grid_predictions(model, loader, model_device):
    """Predicted class index for every row served by ``loader``, in loader order."""
    chunks = []
    with torch.no_grad():
        for grid_batch, _ in loader:
            logits = model(grid_batch.float().to(model_device))
            chunks.append(torch.argmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks)


def _draw_panel(loc, preds, title):
    """Draw one decision-region panel: grid coloured by class, then the data on top."""
    ax = plt.subplot2grid((1, 2), loc)
    ax.set_xticks([])
    ax.set_yticks([])

    # Boolean masks replace a per-point Python loop over the whole grid and
    # also cope with a class that has no grid points at all.
    is_class_0 = preds == 0
    plt.scatter(arr[is_class_0, 0], arr[is_class_0, 1], s=marker_size)
    plt.scatter(arr[~is_class_0, 0], arr[~is_class_0, 1], s=marker_size)
    plt.scatter(points[:,0],points[:,1])
    plt.scatter(points_l[:,0],points_l[:,1])
    plt.title(title)
    return ax


unlabeled_iterator = iter(unlabeledLoader)
labeled_iterator = iter(labeledLoader)

supervised_mpl = TMClassifier()   # teacher
metaPseudoLabel = TMClassifier()  # student

supervised_mpl.train(True)
metaPseudoLabel.train(True)

lr = 2e-1
opt_st = torch.optim.Adam(supervised_mpl.parameters(), lr=lr)
opt_mpl = torch.optim.Adam(metaPseudoLabel.parameters(), lr=lr)

epoch = 0           # counts optimisation steps (one batch pair per iteration)
max_epoch = 2e5+1   # hard stop so the loop cannot run forever

best_acc = 0.8      # a student snapshot is kept only once it beats this accuracy
best_student_model = None

noise = 0.2         # scale of the uniform [0, noise) perturbation used as augmentation

# The models are never moved off their construction device, so inputs are sent
# to wherever the parameters actually live rather than to the global `device`.
model_device = next(supervised_mpl.parameters()).device
eval_points = torch.tensor(points, dtype=torch.float32, device=model_device)

while epoch < max_epoch:
    # Draw a fresh unlabeled batch every step and build its augmented copy from it.
    (uo_point, _), unlabeled_iterator = _next_batch(unlabeled_iterator, unlabeledLoader)
    ua_point = uo_point + torch.rand(uo_point.shape) * noise
    (l_point, l_label), labeled_iterator = _next_batch(labeled_iterator, labeledLoader)

    supervised_mpl.train(True)
    metaPseudoLabel.train(True)

    step = meta_pseudo_labels_step(teacher_model=supervised_mpl,
                                   teacher_optimizer=opt_st,
                                   student_model=metaPseudoLabel,
                                   student_optimizer=opt_mpl,
                                   uo_data=uo_point,
                                   ua_data=ua_point,
                                   l_data=l_point,
                                   l_label=l_label)

    supervised_mpl.train(False)
    metaPseudoLabel.train(False)

    teacher_acc = _accuracy(supervised_mpl, eval_points, labels, num_samples)
    student_acc = _accuracy(metaPseudoLabel, eval_points, labels, num_samples)

    if student_acc > best_acc:
        best_acc = student_acc
        best_student_model = copy.deepcopy(metaPseudoLabel)

    if float(student_acc) > 0.99:
        break

    if epoch % 1e3 == 0:
        print(f'Epoch: {epoch} \tTeacher Loss: {step[0].item():.4f} \tStudent Loss: {step[1].item():.4f} \tMPL Loss: {step[2].item():.4f}')
        print(f'\tTeacher Acc: {teacher_acc.item()*num_samples} / {num_samples} \tStudent Acc: {student_acc.item()*num_samples} / {num_samples}')

    if epoch % 5e3 == 0:
        # Rendering is best-effort: a plotting failure must not abort training,
        # but it is reported instead of being silently swallowed.
        try:
            s_mpl_preds = _grid_predictions(supervised_mpl, renderLoader, model_device)
            mpl_preds = _grid_predictions(metaPseudoLabel, renderLoader, model_device)

            _draw_panel((0, 0), s_mpl_preds, 'Supervised - MPL')
            ax = _draw_panel((0, 1), mpl_preds, 'Meta Pesudo Label')

            ax.figure.set(figwidth=14, figheight=6)

            plt.show()
        except Exception as err:
            print(f'pass ({err!r})')

    epoch += 1


In [None]:
# NOTE(review): disabled legacy training loop -- `while(False)` means the body
# below never runs. As written it also reads names that are not assigned
# anywhere in the visible notebook (points_l_tensor, labels_l_tensor, creterion,
# threshold, output_t_ua, loss_t_supervised), so re-enabling it via the
# commented `while(epoch < max_epoch)` would presumably raise NameError unless
# those are defined elsewhere -- confirm before reviving.
while(False):
# while(epoch < max_epoch):
    for point, label in unlabeledLoader:
        ##### concatenate unlabeled data and labeled data for efficiency
        batch_size = point.shape[0]
        point_noise = point + torch.rand(size=point.shape) * noise
        # point_noise = point + (torch.rand(size=point.shape).numpy()-0.5) * noise
        points_ = torch.concat((point.to(device), point_noise.to(device), points_l_tensor)).float()
        
        # opt_mpl.zero_grad()
        # opt_st.zero_grad()
        
        ##### run teacher model
        # output_t = supervised_mpl(points_)

        # output_t_uo = output_t[:batch_size] # unlabeled original
        # output_t_ua = output_t[batch_size:batch_size*2] # unlabeled augmented
        # output_t_l = output_t[batch_size*2:] # labeled
        # del output_t
        output_t_uo = supervised_mpl(point.to(device))
        
        # loss_t_supervised = creterion(torch.softmax(output_t_l, dim=-1), labels_l_tensor)

        ##### get pseudo_label & compute uda loss
        # # soft_pseudo_label = torch.softmax(output_t_uo, dim=-1)
        # # soft_pseudo_label = torch.softmax(output_t_uo.detach(), dim=-1)
        soft_pseudo_label = torch.softmax(output_t_uo / 0.7, dim=-1).detach() # args.temperture

        max_probs, _ = torch.max(soft_pseudo_label, dim=-1, keepdim=True)
        mask = torch.greater_equal(max_probs, threshold).type(torch.float32)
        # mask = max_probs.ge(threshold).float()  # greater_equal
        # weight_u = lambda_u * min(1., (epoch + 1) / 500.)
        # # loss_t_u = torch.mean(-(soft_pseudo_label * torch.log_softmax(output_t_ua, dim=-1)).sum(dim=-1)) # consistency loss
        loss_t_u = torch.mul(soft_pseudo_label, torch.log_softmax(output_t_ua, dim=-1))
        loss_t_u = torch.mean(torch.mul(-loss_t_u, mask)) # consistency loss
        loss_t_uda = loss_t_u
        # # loss_t_uda = loss_t_u * lambda_u
        # loss_t_uda = loss_t_u * weight_u
        
        # soft_pseudo_label = torch.softmax(output_t_uo.detach() / 0.7, dim=-1)
        # loss_t_uda = creterion(output_t_ua, soft_pseudo_label)

        ##### run student model
        # output = metaPseudoLabel(points_)
        # output_ua = output[batch_size:batch_size*2]
        # output_l = output[batch_size*2:]
        # del output
        output_ua = metaPseudoLabel(point_noise.to(device).float())
        output_l = metaPseudoLabel(points_l_tensor.to(device))

        loss_old_l = creterion(nn.functional.softmax(output_l, dim=-1),
                               labels_l_tensor).detach()

        # student is trained with augmented data 
        # https://github.com/google-research/google-research/issues/534#issuecomment-769559165

        opt_mpl.zero_grad()

        loss = creterion(nn.functional.softmax(output_ua, dim=-1),
                        torch.argmax(soft_pseudo_label, dim=-1).long())
        # loss = creterion(output_ua, torch.softmax(output_t_uo.detach(), dim=-1))    # get loss of student on unlabeled augmented input using pseudo label
        # loss = creterion(output_ua, output_t_uo.detach())    # get loss of student on unlabeled augmented input using pseudo label

        ##### update student
        
        loss.backward()#retain_graph=True)
        opt_mpl.step()

        ##### compute MPL loss with updated student
        output_new_l = metaPseudoLabel(points_l_tensor)
        loss_new_l = creterion(nn.functional.softmax(output_new_l, dim=-1),
                               labels_l_tensor).detach()
                               
        dot_product = (loss_new_l - loss_old_l)
        loss_t_mpl = torch.argmax(nn.functional.softmax(output_t_ua, dim=-1), dim=-1, keepdim=True)
        loss_t_mpl = -torch.mean(torch.mul(loss_t_mpl, nn.functional.log_softmax(output_t_ua, dim=-1)))

        # output_l = metaPseudoLabel(points_l_tensor)
        # loss_t_mpl = creterion(output_l, labels_l_tensor)
        # loss_t_mpl = creterion(output_l.detach(), labels_l_tensor)

        opt_st.zero_grad()
        loss_t = loss_t_mpl + loss_t_supervised + loss_t_uda
        loss_t.backward()
        opt_st.step()

    
    infer = metaPseudoLabel(point.to(device))
    student_acc = sum((torch.argmax(infer, -1) == label.to(device)).type(torch.int))

    if(student_acc/batch_size > best_acc):
            best_acc = student_acc
            best_student_model = copy.deepcopy(metaPseudoLabel)

    if float(student_acc/batch_size) > 0.99:
        break
    
    if(epoch%1e3 == 0):
        _, preds = torch.max(output_t_uo, dim=-1)
        acc_t = torch.sum(preds == label.to(device))
        

        print(f'Epoch: {epoch} \tLoss: {loss.item():.4f} \tTeacher Acc: {acc_t.item()} / {point.shape[0]} \tMPL Acc: {student_acc.item()} / {point.shape[0]}')
        print(f'\t UDA loss: {loss_t_uda:.4f}\t MPL loss: {loss_t_mpl:.4f}\t Supervised loss: {loss_t_supervised:.4f}')


    if(epoch%5e3 == 0):
        try:
            s_mpl_preds = np.array([])
            mpl_preds = np.array([])

            for point, _ in renderLoader:
                _, s_mpl_pred = torch.max(supervised_mpl(point.float().to(device)), 1)
                _, mpl_pred = torch.max(metaPseudoLabel(point.float().to(device)), dim=-1)

                s_mpl_preds = np.concatenate((s_mpl_preds,s_mpl_pred.cpu().numpy()))
                mpl_preds = np.concatenate((mpl_preds,mpl_pred.cpu().numpy()))

            s_mpl_pred_points_0 = []
            s_mpl_pred_points_1 = []
            mpl_pred_points_0 = []
            mpl_pred_points_1 = []

            for i in range(arr.shape[0]):
                if(s_mpl_preds[i] == 0):
                    s_mpl_pred_points_0.append(arr[i])
                else:
                    s_mpl_pred_points_1.append(arr[i])
                if(mpl_preds[i] == 0):
                    mpl_pred_points_0.append(arr[i])
                else:
                    mpl_pred_points_1.append(arr[i])

                    gridshape = (2, 2)
            
            gridshape = (1,2)

            loc = (0,0)
            ax = plt.subplot2grid(gridshape, loc)
            ax.set_xticks([])
            ax.set_yticks([])

            plt.scatter(np.array(s_mpl_pred_points_0)[:,0], np.array(s_mpl_pred_points_0)[:,1], s=marker_size)
            plt.scatter(np.array(s_mpl_pred_points_1)[:,0], np.array(s_mpl_pred_points_1)[:,1], s=marker_size)
            plt.scatter(points[:,0],points[:,1])
            plt.scatter(points_l[:,0],points_l[:,1])
            plt.title('Supervised - MPL')

            loc = (0,1)
            ax = plt.subplot2grid(gridshape, loc)
            ax.set_xticks([])
            ax.set_yticks([])

            plt.scatter(np.array(mpl_pred_points_0)[:,0], np.array(mpl_pred_points_0)[:,1], s=marker_size)
            plt.scatter(np.array(mpl_pred_points_1)[:,0], np.array(mpl_pred_points_1)[:,1], s=marker_size)
            plt.scatter(points[:,0],points[:,1])
            plt.scatter(points_l[:,0],points_l[:,1])
            plt.title('Meta Pesudo Label')

            ax.figure.set(figwidth=14, figheight=6)

            plt.show()
        except:
            print('pass')

    epoch += 1