## 状态压缩

In [None]:
n, k = 120, 5
bit = 1 << k  # single-bit mask selecting position k
print(bin(n))

# Read bit k (positions count from 0 at the least significant end).
# Shifts bind tighter than &, so this is (n >> k) & 1.
print(n >> k & 1)
# Keep the low k bits: n mod 2**k equals n & (2**k - 1).
print(bin(n % bit))
# Keep the high k bits by shifting away everything below them
# (requires k <= n.bit_length(), otherwise the shift count is negative).
print(bin(n >> (n.bit_length() - k)))

# Toggle bit k.
print(bin(n ^ bit))
# Force bit k to 1.
print(bin(n | bit))
# Force bit k to 0.
print(bin(n & ~bit))

## 用二进制表示数位集合

### 要求集合中不能有两个相邻的元素

In [None]:
# A set stored as a bitmask has no two adjacent elements exactly when
# mask & (mask >> 1) is 0; any non-zero result marks an adjacent pair.
mask = 0b101010010
print(mask & (mask >> 1))   # 0: no neighbouring set bits
mask = 0b101011010
print(mask & (mask >> 1))   # non-zero: bits 3 and 4 are neighbours

### 子集枚举：0 位与 mask 一致的所有元素（即 mask 的子集）

In [None]:
# Walk every submask of `mask` in decreasing order, including 0 (the empty set).
# (m - 1) turns the lowest set bit of m off and every bit below it on;
# "& mask" keeps only the positions mask allows, giving the next smaller submask.
m = mask
while True:
    print(bin(m))
    if m == 0:  # the empty set is the last submask; stop after printing it
        break
    m = (m - 1) & mask

### 超集枚举：1 位与 mask 一致的所有元素（即 mask 的超集）

In [None]:
# Walk every superset of `mask` that fits in mask.bit_length() bits, in increasing order.
# (m + 1) counts up through the free bits; "| mask" forces mask's bits back on,
# giving the next larger superset.
m = mask
while m < (1 << mask.bit_length()):  # strict: values >= this need more bits than mask has
    print(bin(m))
    m = (m + 1) | mask

### 枚举恰好有 k 个 1 的所有元素（Gosper's hack）

In [None]:

# Gosper's hack: visit every n-bit mask with exactly `ones` set bits, in increasing order.
n = 10      # universe size: masks range over [0, 2**n)
ones = 6    # number of set bits every enumerated mask must have
m = (1 << ones) - 1  # smallest mask with `ones` set bits
while m < (1 << n):
    print(bin(m))
    if m == 0:
        # ones == 0: 0 is the only such mask, and the step below would divide by t == 0.
        break
    t = m & -m                      # lowest set bit of m
    r = m + t                       # carry it past the lowest run of 1s
    m = r | (((m ^ r) >> 2) // t)   # refill the leftover 1s at the bottom

## 其他

### 判断是否为 4 的幂次

In [None]:
def is_power_of_four(n):
    """Return True if ``n`` equals 4**j for some integer j >= 0.

    A power of four is a positive power of two (exactly one set bit) whose
    set bit sits at an even position. Checking the position through
    ``bit_length`` instead of a fixed 0x55555555 mask keeps this correct for
    Python's arbitrary-size ints (e.g. 4**16 == 2**32 has no bit inside a
    32-bit mask).
    """
    return n > 0 and (n & (n - 1)) == 0 and (n.bit_length() - 1) % 2 == 0

is_power_of_four(64)

### 找出 0~n 中缺失的那个数（其余 n 个数各出现一次）

In [None]:
nums = [0, 1, 2, 3, 4, 6, 7, 8, 9]
# XOR every index and every value together. Each present number cancels its
# matching index; after folding in len(nums) only the missing value remains.
ret = 0
for i, x in enumerate(nums):
    ret ^= i ^ x
ret ^ len(nums)

### 求闭区间 [m, n] 内所有整数的按位与（AND）

In [None]:
def bitrange(m, n):
    """Return the bitwise AND of every integer in the closed range between m and n.

    If m and n first differ at bit h, the range contains both
    prefix+0+11..1 and prefix+1+00..0, so every bit at or below h ANDs to 0
    and only the shared prefix above h survives. Argument order does not
    matter.
    """
    if (m < 0) != (n < 0):
        # The range contains both -1 and 0, and -1 & 0 == 0. A shift-until-equal
        # loop never terminates here (-1 >> 1 stays -1 while 0 >> 1 stays 0).
        return 0
    shift = (m ^ n).bit_length()  # number of low bits that vary across the range
    return (m >> shift) << shift

bitrange(5, 7)

## BitSet

In [None]:
DIV = 63  # bits stored per bucket; every bucket is a Python int below 2**DIV


class BitSet:
    """Fixed-size set of bit positions ``0 .. n-1`` packed into DIV-bit ints.

    Position ``p`` lives in bucket ``p // DIV`` at bit ``p % DIV``. Bits at
    positions ``>= n`` in the last bucket are always kept at 0, so ``count``,
    ``__str__`` and the binary operators never see stray bits.
    """

    @staticmethod
    def get_bucket_size(x):
        """Return how many DIV-bit buckets are needed to hold ``x`` bits."""
        return ((x - 1) // DIV) + 1

    @staticmethod
    def popcount(x):
        """Return the number of set bits of ``x`` (SWAR popcount).

        Valid for 0 <= x < 2**64, which covers every bucket value.
        """
        x = x - ((x >> 1) & 0x5555555555555555)
        x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333)
        x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f
        x = x + (x >> 8)
        x = x + (x >> 16)
        x = x + (x >> 32)
        return x & 0x0000007f

    def __init__(self, n):
        self.n = n
        self.buckets = [0] * self.get_bucket_size(n)

    def _trim(self):
        """Clear the unused high bits (positions >= n) of the last bucket."""
        unused = len(self.buckets) * DIV - self.n
        if self.buckets and unused:
            self.buckets[-1] &= (1 << (DIV - unused)) - 1

    def in_range(self, p):
        """Return True if ``p`` is a valid position, i.e. ``0 <= p < n``."""
        return 0 <= p < self.n

    def set(self, p, val=True):
        """Set position ``p`` to 1 when ``val`` is truthy, otherwise clear it."""
        assert self.in_range(p)
        if val:
            self.buckets[p // DIV] |= 1 << (p % DIV)
        else:
            self.buckets[p // DIV] &= ~(1 << (p % DIV))

    def test(self, p):
        """Return 1 if position ``p`` is set, else 0."""
        assert self.in_range(p)
        return self.buckets[p // DIV] >> (p % DIV) & 1

    def flip(self, p=None):
        """Toggle position ``p``, or every position in ``[0, n)`` when ``p`` is None."""
        if p is None:
            full = (1 << DIV) - 1
            for i, x in enumerate(self.buckets):
                self.buckets[i] = x ^ full
            self._trim()  # the XOR also turned on positions >= n; clear them again
        else:
            assert self.in_range(p)
            self.buckets[p // DIV] ^= 1 << (p % DIV)

    def any(self):
        """Return True if at least one position is set."""
        return any(self.buckets)

    def count(self):
        """Return the number of set positions."""
        return sum(self.popcount(x) for x in self.buckets)

    def lowbit(self):
        """Return the smallest set position, or None if the set is empty."""
        for i, x in enumerate(self.buckets):
            if x:
                return i * DIV + (x & -x).bit_length() - 1
        return None

    def touch(self, l, r):
        """Return the smallest set position in the closed range ``[l, r]``, or -1.

        The range is clamped to ``[0, n - 1]`` first.
        """
        l = max(l, 0)
        r = min(r, self.n - 1)
        if l > r:
            return -1
        L, R = l // DIV, r // DIV
        for i in range(L, R + 1):
            x = self.buckets[i]
            if i == L:
                x &= ~((1 << (l % DIV)) - 1)  # drop positions below l
            if i == R:
                x &= (1 << (r % DIV + 1)) - 1  # drop positions above r
            if x:
                return i * DIV + (x & -x).bit_length() - 1
        return -1

    def resize(self, m):
        """Change the size to ``m`` bits, keeping positions below ``min(n, m)``.

        Shrinking discards every position >= m; growing adds cleared positions.
        """
        newsz = self.get_bucket_size(m)
        sz = len(self.buckets)
        if newsz < sz:
            self.buckets = self.buckets[:newsz]
        else:
            self.buckets += [0] * (newsz - sz)
        self.n = m
        self._trim()  # positions >= m may still share the last kept bucket
        return None

    def size(self):
        """Return the number of positions ``n``."""
        return self.n

    def __len__(self):
        return self.n

    def __eq__(self, rhs):
        """Two BitSets are equal when they have the same size and the same bits."""
        if not isinstance(rhs, BitSet):
            return NotImplemented
        return self.n == rhs.n and self.buckets == rhs.buckets

    def __and__(self, rhs):
        """Return a new BitSet (size = larger operand) of positions set in both."""
        ret = BitSet(max(self.size(), rhs.size()))
        m = min(len(self.buckets), len(rhs.buckets))
        for i in range(m):
            ret.buckets[i] = self.buckets[i] & rhs.buckets[i]
        return ret

    def __or__(self, rhs):
        """Return a new BitSet (size = larger operand) of positions set in either."""
        ret = BitSet(max(self.size(), rhs.size()))
        for i in range(len(ret.buckets)):
            if i < len(self.buckets): ret.buckets[i] |= self.buckets[i]
            if i < len(rhs.buckets): ret.buckets[i] |= rhs.buckets[i]
        return ret

    def __xor__(self, rhs):
        """Return a new BitSet (size = larger operand) of positions set in exactly one."""
        ret = BitSet(max(self.size(), rhs.size()))
        for i in range(len(ret.buckets)):
            if i < len(self.buckets): ret.buckets[i] ^= self.buckets[i]
            if i < len(rhs.buckets): ret.buckets[i] ^= rhs.buckets[i]
        return ret

    def __invert__(self):
        """Return a new BitSet with every position toggled; ``self`` is unchanged."""
        ret = BitSet(self.n)
        ret.buckets = list(self.buckets)
        ret.flip()
        return ret

    def __ilshift__(self, m):
        """Move every bit from position j to j + m in place; bits reaching n are dropped."""
        p, w = m // DIV, m % DIV
        # Walk from the top bucket down so every source bucket is read before it is overwritten.
        for i in range(len(self.buckets) - 1, -1, -1):
            if w == 0:
                self.buckets[i] = 0 if i - p < 0 else self.buckets[i - p]
            else:
                a = 0 if i - p - 1 < 0 else self.buckets[i - p - 1] >> (DIV - w)
                b = 0 if i - p < 0 else self.buckets[i - p] << w
                self.buckets[i] = (a | b) % (1 << DIV)
        self._trim()  # bits shifted past n - 1 may still sit inside the last bucket
        return self

    def __irshift__(self, m):
        """Move every bit from position j to j - m in place; bits below 0 are dropped."""
        p, w = m // DIV, m % DIV
        n = len(self.buckets)
        # Walk upward: bucket i only reads buckets i + p and i + p + 1, which are not yet rewritten.
        for i in range(n):
            a = self.buckets[i + p] >> w if i + p < n else 0
            b = self.buckets[i + p + 1] << (DIV - w) if i + p + 1 < n and w > 0 else 0
            self.buckets[i] = (a | b) % (1 << DIV)
        return self

    def __str__(self):
        """Return the bits as a '0'/'1' string of length n, position 0 first."""
        ret = []
        for i, b in enumerate(self.buckets):
            tmp = str(bin(b))[2:][::-1]
            while len(tmp) + i * DIV < min(self.n, (i + 1) * DIV):
                tmp += '0'
            ret.append(tmp)
        return ''.join(ret)

    def __getitem__(self, p):
        return self.test(p)

    def __setitem__(self, p, v):
        self.set(p, v)

    def __hash__(self):
        # Consistent with __eq__: equal BitSets have equal bucket lists.
        return hash(tuple(self.buckets))

    def __iter__(self):
        """Iterate over the raw bucket integers (not over individual positions)."""
        return iter(tuple(self.buckets))



In [None]:
from random import randint
n = 80
# Random 0/1 string of length n; randint is called once per character, in order.
s = ''.join(str(randint(0, 1)) for _ in range(n))

b = BitSet(n)
a = BitSet(n)
print(s)
# Load the pattern into b and mirror each bit into a.
for i, ch in enumerate(s):
    if ch == '1':
        b.set(i)
    a.set(i, b.test(i))

b.flip()  # invert b in place; a keeps the original pattern

print(b)
print(a)
print(b.any())
print(b.count())
print(a ^ b)

In [None]:
print(s)
# Shift `a` in place by 10 positions toward position 0 (BitSet.__irshift__).
# The printed form lists position 0 first, so the pattern loses its first
# 10 characters and is padded with zeros at the end.
a >>= 10
print(a)
print(b)

In [None]:
# Iterating a BitSet yields its raw bucket integers (each holding up to DIV
# positions, lowest position in the lowest bit), not individual bits.
for x in a:
    print(bin(x))

In [None]:
# Query the first set position of `a` inside the closed ranges [0, 5] and [1, 5];
# touch returns -1 when the range holds no set bit.
print(a.touch(0, 5))
print(a.touch(1, 5))