In [1]:
import gym
from LRL import *

In [2]:
class BatchOfRubiks:
    """Face turns applied to a whole batch of cube states at once.

    The wrapped tensor is indexed as ``group[:, sticker]``: the first axis
    enumerates cubes, the second the stickers. Sticker numbers passed to the
    cycle helpers are 1-based. All operations modify the wrapped tensor in
    place and return ``self`` so that calls can be chained.
    """

    # For each of the six faces: the sticker 4-cycles (1-based) that make up
    # one turn of that face.
    _MOVES = {
        0: ((1, 3, 8, 6), (2, 5, 7, 4), (33, 9, 41, 32), (36, 12, 44, 29), (38, 14, 46, 27)),
        1: ((9, 11, 16, 14), (10, 13, 15, 12), (38, 17, 43, 8), (39, 20, 42, 5), (40, 22, 41, 3)),
        2: ((17, 19, 24, 22), (18, 21, 23, 20), (48, 16, 40, 25), (45, 13, 37, 28), (43, 11, 35, 30)),
        3: ((25, 27, 32, 30), (26, 29, 31, 28), (19, 33, 6, 48), (21, 34, 4, 47), (24, 35, 1, 46)),
        4: ((33, 35, 40, 38), (34, 37, 39, 36), (25, 17, 9, 1), (26, 18, 10, 2), (27, 19, 11, 3)),
        5: ((41, 43, 48, 46), (42, 45, 47, 44), (6, 14, 22, 30), (7, 15, 23, 31), (8, 16, 24, 32)),
    }

    def __init__(self, groups):
        # Kept by reference: every rotation below writes into this tensor.
        self.group = groups

    def forward_cycle(self, a, b, c, d):
        """Shift sticker columns a -> b -> c -> d -> a (1-based indices)."""
        # Snapshot all four columns first so the writes cannot see each other.
        was_a, was_b, was_c, was_d = (
            self.group[:, pos - 1].clone() for pos in (a, b, c, d)
        )
        self.group[:, a - 1] = was_d
        self.group[:, b - 1] = was_a
        self.group[:, c - 1] = was_b
        self.group[:, d - 1] = was_c
        return self

    def backward_cycle(self, a, b, c, d):
        """Inverse of ``forward_cycle``: shift columns a <- b <- c <- d <- a."""
        was_a, was_b, was_c, was_d = (
            self.group[:, pos - 1].clone() for pos in (a, b, c, d)
        )
        self.group[:, a - 1] = was_b
        self.group[:, b - 1] = was_c
        self.group[:, c - 1] = was_d
        self.group[:, d - 1] = was_a
        return self

    def rotate(self, command):
        """Apply one move to every cube in the batch.

        Commands below 6 run the face's cycles forward; other commands run
        the cycles of face ``command % 6`` backward. The chosen cycle method
        is also stored on ``self.cycle``.
        """
        if command < 6:
            self.cycle = self.forward_cycle
        else:
            self.cycle = self.backward_cycle

        # A face id with no table entry leaves the batch untouched.
        for quad in self._MOVES.get(command % 6, ()):
            self.cycle(*quad)
        return self

def next_states(states):
    """Enumerate the result of every one of the 12 moves for a batch of cubes.

    Parameters:
        states: batch of flat cube states, indexed as ``states[cube, sticker]``.
            The input is cloned per move and is not modified.

    Returns a tuple ``(ns, rewards, solved)``:
        ns: the successor states, stacked on a new move axis at dim 1.
        rewards: a ``(batch, 12)`` tensor filled with -1.
        solved: float tensor, 1 where the successor equals the solved cube
            and 0 elsewhere.
    """
    ns = torch.cat([BatchOfRubiks(states.clone()).rotate(i).group[:, None] for i in range(12)], dim=1)

    solved_cube = Tensor([0]*8 + [1]*8 + [2]*8 + [3]*8 + [4]*8 + [5]*8)
    # A successor is unsolved if at least one sticker differs from the solved cube.
    # `.float()` keeps the mask on the same device as `ns` instead of forcing
    # a CUDA tensor type, so the function no longer requires a GPU by itself.
    unsolved = ((ns != solved_cube).sum(dim=2) > 0).float()

    return ns, Tensor(states.size()[0], 12).zero_() - 1, 1 - unsolved

In [3]:
class Rubik:
    """Single Rubik's-cube environment with a gym-like ``reset``/``step`` API.

    The cube is a flat list of 48 sticker colours: six faces of eight
    stickers each, stored in the order L, F, R, B, U, D (face centres are
    not stored). Commands 0-5 turn a face with ``forward_cycle``; commands
    6-11 undo the corresponding turn with ``backward_cycle``.
    """

    # For each face: the sticker 4-cycles (1-based indices) of one turn.
    # Earlier attempts based on cyclic shifts in a 2-D array without NumPy led
    # to an exponential growth of hacks; an explicit table of cycles seems to
    # be the most concise and painless option.
    _FACE_CYCLES = (
        ((1, 3, 8, 6), (2, 5, 7, 4), (33, 9, 41, 32), (36, 12, 44, 29), (38, 14, 46, 27)),
        ((9, 11, 16, 14), (10, 13, 15, 12), (38, 17, 43, 8), (39, 20, 42, 5), (40, 22, 41, 3)),
        ((17, 19, 24, 22), (18, 21, 23, 20), (48, 16, 40, 25), (45, 13, 37, 28), (43, 11, 35, 30)),
        ((25, 27, 32, 30), (26, 29, 31, 28), (19, 33, 6, 48), (21, 34, 4, 47), (24, 35, 1, 46)),
        ((33, 35, 40, 38), (34, 37, 39, 36), (25, 17, 9, 1), (26, 18, 10, 2), (27, 19, 11, 3)),
        ((41, 43, 48, 46), (42, 45, 47, 44), (6, 14, 22, 30), (7, 15, 23, 31), (8, 16, 24, 32)),
    )

    def __init__(self, complexity=1, limit=50):
        """Create the environment and scramble it once.

        Parameters:
            complexity: number of random moves applied per scramble round in
                ``reset``; must be at least 1.
            limit: number of steps after which ``step`` reports
                ``info["died"]`` as True.

        Raises:
            ValueError: if ``complexity`` is below 1 (``reset`` could never
                leave the solved state and would loop forever).
        """
        if complexity < 1:
            raise ValueError("complexity must be at least 1, got %r" % (complexity,))
        self.complexity = complexity
        self.limit = limit
        self.action_space = gym.spaces.Discrete(12)
        self.init_state = [0]*8 + [1]*8 + [2]*8 + [3]*8 + [4]*8 + [5]*8
        self.steps = 0
        self.next_states_function = next_states
        self.reset()

    def reset(self):
        """Scramble the cube and return a copy of the new state.

        Starts from the solved cube and applies ``complexity`` random moves,
        repeating until the result differs from the solved cube. The step
        counter is reset to zero.
        """
        self.group = self.init_state[:]

        while self.group == self.init_state:
            for _ in range(self.complexity):
                self.rotate(np.random.randint(0, 12))

        self.steps = 0
        # Hand out a copy so later moves do not rewrite the caller's observation.
        return list(self.group)

    def forward_cycle(self, a, b, c, d):
        """Shift stickers a -> b -> c -> d -> a (1-based indices)."""
        self.group[a-1], self.group[b-1], self.group[c-1], self.group[d-1] = self.group[d-1], self.group[a-1], self.group[b-1], self.group[c-1]
        return self

    def backward_cycle(self, a, b, c, d):
        """Inverse of ``forward_cycle``: shift stickers a <- b <- c <- d <- a."""
        self.group[a-1], self.group[b-1], self.group[c-1], self.group[d-1] = self.group[b-1], self.group[c-1], self.group[d-1], self.group[a-1]
        return self

    def rotate(self, command):
        """Apply one move in place and return ``self``.

        Commands below 6 turn face ``command`` forward; other commands turn
        face ``command % 6`` backward, mirroring ``BatchOfRubiks.rotate``.
        The direction is chosen here on every call, so the result no longer
        depends on which move ``step`` happened to execute last.
        """
        self.cycle = self.forward_cycle if command < 6 else self.backward_cycle
        for quad in self._FACE_CYCLES[command % 6]:
            self.cycle(*quad)
        return self

    def step(self, command):
        """Execute one move.

        Returns ``(state, reward, done, info)`` where ``state`` is a copy of
        the sticker list, ``reward`` is always -1, ``done`` tells whether the
        cube is solved and ``info["died"]`` tells whether the step counter
        has exceeded ``limit``.
        """
        self.rotate(command)

        self.steps += 1
        done = self.group == self.init_state
        died = self.steps > self.limit
        # Copy: the internal list is mutated by every following move.
        return list(self.group), -1, done, {"died": died}

    def __str__(self):
        """Render the unfolded cube: U on top, then L F R B side by side, then D."""
        cells = [str(colour) for colour in self.group]

        def face_rows(start, label):
            # Eight stickers of one face as three text rows; the face letter
            # stands in for the centre.
            return [
                ''.join(cells[start:start + 3]),
                cells[start + 3] + label + cells[start + 4],
                ''.join(cells[start + 5:start + 8]),
            ]

        left_pad, right_pad = '.' * 3, '.' * 6
        sides = [face_rows(start, label) for start, label in ((0, 'L'), (8, 'F'), (16, 'R'), (24, 'B'))]

        lines = [left_pad + row + right_pad for row in face_rows(32, 'U')]
        lines += [''.join(face[row] for face in sides) for row in range(3)]
        lines += [left_pad + row + right_pad for row in face_rows(40, 'D')]
        return '\n'.join(lines) + '\n'

In [4]:
# Build the environment with default settings; the constructor calls reset(),
# which scrambles the cube with random moves.
rubik = Rubik()



In [5]:
# Show the unfolded cube (rendered by Rubik.__str__).
print(rubik)

...444......
...4U4......
...444......
333000111222
0L01F12R23B3
000111222333
...555......
...5D5......
...555......



In [6]:
# The flat sticker list of the current cube as a tensor, displayed for inspection.
R = Tensor(rubik.group)
R

tensor([3., 3., 3., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 2., 2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 4., 4., 4., 4.,
        4., 4., 4., 4., 5., 5., 5., 5., 5., 5., 5., 5.], device='cuda:0')

In [7]:
# 8x8 permutation matrices acting on the eight stickers of a single face.
# Each matrix is written as a list of source positions: row i carries a single 1
# in column src[i], so (M.mv(face))[i] == face[src[i]].
# RotateL is the transpose of RotateR, i.e. it undoes it.
RotateR = Tensor([
    [int(col == src) for col in range(8)]
    for src in (5, 3, 0, 6, 1, 7, 4, 2)
])
RotateL = Tensor([
    [int(col == src) for col in range(8)]
    for src in (2, 4, 7, 1, 6, 0, 3, 5)
])

In [8]:
# Show the cube again; nothing above has modified rubik.group since the last print.
print(rubik)

...444......
...4U4......
...444......
333000111222
0L01F12R23B3
000111222333
...555......
...5D5......
...555......



In [9]:
def rotateX(cube):
    """Re-index a flat 48-sticker cube for the whole-cube turn "X".

    The input is read as six consecutive 8-sticker faces in the order
    L, F, R, B, U, D. The output uses the same slot order and is filled from
    D, F, U, B, L, R respectively; every face goes through RotateR except B,
    which goes through RotateL.
    """
    left, front, right, back, up, down = (cube[lo:lo + 8] for lo in range(0, 48, 8))
    # (face permutation, source face) for the output slots L, F, R, B, U, D.
    plan = (
        (RotateR, down),
        (RotateR, front),
        (RotateR, up),
        (RotateL, back),
        (RotateR, left),
        (RotateR, right),
    )
    return torch.cat([turn.mv(face) for turn, face in plan])
def rotateY(cube):
    """Re-index a flat 48-sticker cube for the whole-cube turn "Y".

    The input is read as six consecutive 8-sticker faces in the order
    L, F, R, B, U, D. In the output, L and R stay in their slots (passed
    through RotateL and RotateR), the F slot receives D, the U slot receives
    F, and the B and D slots receive U and B after two RotateL applications.
    """
    left, front, right, back, up, down = (cube[lo:lo + 8] for lo in range(0, 48, 8))

    # Two applications of RotateL in a row.
    up_twice = RotateL.mv(RotateL.mv(up))
    back_twice = RotateL.mv(RotateL.mv(back))

    # Output slots in the order L, F, R, B, U, D.
    pieces = [
        RotateL.mv(left),
        down,
        RotateR.mv(right),
        up_twice,
        front,
        back_twice,
    ]
    return torch.cat(pieces)
def rotateZ(cube):
    """Re-index a flat 48-sticker cube for the whole-cube turn "Z".

    The input is read as six consecutive 8-sticker faces in the order
    L, F, R, B, U, D. The four side faces move one slot over unchanged
    (the L, F, R, B slots receive F, R, B, L), while U goes through RotateR
    and D through RotateL in place.
    """
    left, front, right, back, up, down = (cube[lo:lo + 8] for lo in range(0, 48, 8))

    side_band = [front, right, back, left]
    caps = [RotateR.mv(up), RotateL.mv(down)]
    return torch.cat(side_band + caps)

In [10]:
# Re-index the cube with rotateZ and store it back as a plain list of ints so
# that Rubik.__str__ can render it.
# `int` replaces the `np.int` alias, which was removed in NumPy 1.24.
# NOTE: this cell mutates `rubik` in place; re-running it rotates the cube again.
rubik.group = rotateZ(Tensor(rubik.group)).cpu().data.numpy().astype(int).tolist()

In [11]:
# Show the cube after the rotateZ re-indexing applied in the previous cell.
print(rubik)

...444......
...4U4......
...444......
000111222333
1L12F23R30B0
111222333000
...555......
...5D5......
...555......



In [12]:
def one_hot(x, nb_digits):
    """One-hot encode a 1-D tensor of integer-valued entries.

    Parameters:
        x: 1-D tensor whose values are whole numbers in ``[0, nb_digits)``
            (a float tensor holding whole numbers is fine).
        nb_digits: width of the encoding.

    Returns:
        A float32 tensor of shape ``(len(x), nb_digits)`` on the same device
        as ``x``; row ``i`` has a single 1 in column ``x[i]``.
    """
    # Allocate on x's device and cast the indices with .long() instead of a
    # hard-coded CUDA tensor type, so the function also works without a GPU.
    x_onehot = torch.zeros(x.size()[0], nb_digits, dtype=torch.float32, device=x.device)
    x_onehot.scatter_(1, x.reshape(-1, 1).long(), 1)
    return x_onehot

In [13]:
# Turn each whole-cube re-indexing into a 48x48 matrix: run the function on the
# sticker ids 0..47 and one-hot encode where every output slot reads from.
RotateX, RotateY, RotateZ = (
    one_hot(turn(Tensor(np.arange(48))), 48)
    for turn in (rotateX, rotateY, rotateZ)
)

In [14]:
def _orientation_stack(turn_x, turn_y, turn_z):
    """Stack 24 products of the three whole-cube matrices along a new dim 0.

    Layer order: X^4, X, X^2, X^3, followed by one group of four
    (lead, lead*X, lead*X^2, lead*X^3) for each lead in Y, Y^2, Y^3, Z, Z^3.
    Powers are computed once and reused; this assumes the inputs are
    permutation matrices, for which regrouping the products is exact.
    """
    x2 = turn_x.mm(turn_x)
    x3 = turn_x.mm(x2)
    x4 = turn_x.mm(x3)
    y2 = turn_y.mm(turn_y)
    y3 = turn_y.mm(y2)
    z3 = turn_z.mm(turn_z.mm(turn_z))

    layers = [x4, turn_x, x2, x3]
    for lead in (turn_y, y2, y3, turn_z, z3):
        layers.append(lead)
        layers.extend(lead.mm(roll) for roll in (turn_x, x2, x3))

    return torch.cat([layer[None] for layer in layers], dim=0)


INVAIANTS = _orientation_stack(RotateX, RotateY, RotateZ)

In [19]:
# Start from a freshly scrambled cube.
rubik = Rubik()
# One batched product applies every matrix of the stack to the sticker vector;
# each row of the result is the same cube under one re-indexing.
for g in INVAIANTS.matmul(Tensor(rubik.group)).cpu().data.numpy():
    # Overwrite rubik.group so the cube can be printed through Rubik.__str__.
    rubik.group = g.astype(int).tolist()
    print(rubik)

...444......
...4U4......
...444......
111222333000
0L01F12R23B3
000111222333
...555......
...5D5......
...555......

...001......
...0U1......
...001......
555112444033
5L51F24R40B3
555112444033
...223......
...2D3......
...223......

...555......
...5U5......
...555......
222111000333
2L21F10R03B3
333222111000
...444......
...4D4......
...444......

...322......
...3U2......
...322......
444211555330
4L42F15R53B0
444211555330
...100......
...1D0......
...100......

...222......
...1U1......
...111......
100555223444
1L05F52R34B4
100555223444
...333......
...3D3......
...000......

...112......
...1U2......
...112......
555223444100
5L52F34R41B0
555223444100
...330......
...3D0......
...330......

...111......
...1U1......
...222......
223444100555
2L34F41R05B5
223444100555
...000......
...3D3......
...333......

...211......
...2U1......
...211......
444100555223
4L41F05R52B3
444100555223
...033......
...0D3......
...033......

...555......
...5U5......
...555......
000333222111
0L03

In [22]:
# Sanity check before saving: the stack holds 24 matrices of size 48x48.
# (Replaces a leftover IPython help lookup of `.save` on the NumPy array;
# ndarrays have no `save` method - np.save is the function that writes them.)
assert tuple(INVAIANTS.shape) == (24, 48, 48)

In [25]:
# Persist the stacked matrices as a NumPy array; the file is written relative to
# the current working directory.
np.save("RubikInvariantsMatrix.npy", INVAIANTS.cpu().data.numpy())